Spaces:
Sleeping
Sleeping
| # import gradio as gr | |
| # import torch | |
| # import torch.nn as nn | |
| # import pdfplumber | |
| # import json | |
| # import os | |
| # import re | |
| # from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model | |
| # from TorchCRF import CRF | |
| # # --------------------------------------------------------- | |
| # # 1. CONFIGURATION | |
| # # --------------------------------------------------------- | |
| # # Ensure this filename matches exactly what you uploaded to the Space | |
| # MODEL_FILENAME = "layoutlmv3_bilstm_crf_hybrid.pth" | |
| # BASE_MODEL_ID = "microsoft/layoutlmv3-base" | |
| # # Define your labels exactly as they were during training | |
| # LABELS = [ | |
| # "O", | |
| # "B-QUESTION", "I-QUESTION", | |
| # "B-OPTION", "I-OPTION", | |
| # "B-ANSWER", "I-ANSWER", | |
| # "B-SECTION_HEADING", "I-SECTION_HEADING", | |
| # "B-PASSAGE", "I-PASSAGE" | |
| # ] | |
| # LABEL2ID = {l: i for i, l in enumerate(LABELS)} | |
| # ID2LABEL = {i: l for l, i in LABEL2ID.items()} | |
| # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID) | |
| # # --------------------------------------------------------- | |
| # # 2. MODEL ARCHITECTURE | |
| # # --------------------------------------------------------- | |
| # # ⚠️ ACTION REQUIRED: | |
| # # Replace this class with the exact class definition of your | |
| # # NEW HYBRID MODEL. The class name and structure must match | |
| # # what was used when you saved 'layoutlmv3_nonlinear_scratch.pth'. | |
| # # --------------------------------------------------------- | |
| # # --------------------------------------------------------- | |
| # # 2. MODEL ARCHITECTURE (LayoutLMv3 + BiLSTM + CRF) | |
| # # --------------------------------------------------------- | |
| # class HybridModel(nn.Module): | |
| # def __init__(self, num_labels): | |
| # super().__init__() | |
| # self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID) | |
| # # Config for BiLSTM | |
| # hidden_size = self.layoutlm.config.hidden_size # Usually 768 | |
| # lstm_hidden_size = hidden_size // 2 # 384, so bidirectional output is 768 | |
| # # BiLSTM Layer | |
| # # input_size=768, hidden=384, bidir=True -> output_dim = 384 * 2 = 768 | |
| # self.lstm = nn.LSTM( | |
| # input_size=hidden_size, | |
| # hidden_size=lstm_hidden_size, | |
| # num_layers=1, | |
| # batch_first=True, | |
| # bidirectional=True | |
| # ) | |
| # # Dropout (Optional, check if you used this in training) | |
| # self.dropout = nn.Dropout(0.1) | |
| # # Classifier: Maps BiLSTM output (768) to Label count | |
| # self.classifier = nn.Linear(lstm_hidden_size * 2, num_labels) | |
| # # CRF Layer | |
| # self.crf = CRF(num_labels) | |
| # def forward(self, input_ids, bbox, attention_mask, labels=None): | |
| # # 1. LayoutLMv3 Base | |
| # outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask) | |
| # sequence_output = outputs.last_hidden_state # [Batch, Seq, 768] | |
| # # 2. BiLSTM | |
| # # LSTM returns (output, (h_n, c_n)). We only need output. | |
| # lstm_output, _ = self.lstm(sequence_output) # [Batch, Seq, 768] | |
| # # 3. Dropout & Classifier | |
| # lstm_output = self.dropout(lstm_output) | |
| # emissions = self.classifier(lstm_output) # [Batch, Seq, Num_Labels] | |
| # # 4. CRF | |
| # if labels is not None: | |
| # # Training/Eval (Loss) | |
| # log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool()) | |
| # return -log_likelihood.mean() | |
| # else: | |
| # # Inference (Prediction Tags) | |
| # return self.crf.viterbi_decode(emissions, mask=attention_mask.bool()) | |
| # # --------------------------------------------------------- | |
| # # 3. MODEL LOADING LOGIC | |
| # # --------------------------------------------------------- | |
| # model = None | |
| # def load_model(): | |
| # global model | |
| # if model is None: | |
| # print(f"🔄 Loading model from {MODEL_FILENAME}...") | |
| # if not os.path.exists(MODEL_FILENAME): | |
| # raise FileNotFoundError(f"❌ Model file '{MODEL_FILENAME}' not found. Please upload it to the Files tab of your Space.") | |
| # # Initialize the model structure | |
| # model = HybridModel(num_labels=len(LABELS)) | |
| # # Load weights | |
| # try: | |
| # state_dict = torch.load(MODEL_FILENAME, map_location=device) | |
| # model.load_state_dict(state_dict) | |
| # except RuntimeError as e: | |
| # raise RuntimeError(f"❌ State dictionary mismatch. Ensure the 'HybridModel' class structure in app.py matches the model you trained.\nDetails: {e}") | |
| # model.to(device) | |
| # model.eval() | |
| # print("✅ Model loaded successfully.") | |
| # return model | |
| # # --------------------------------------------------------- | |
| # # 4. JSON CONVERSION LOGIC (Your Custom Logic) | |
| # # --------------------------------------------------------- | |
| # def convert_bio_to_structured_json(predictions): | |
| # structured_data = [] | |
| # current_item = None | |
| # current_option_key = None | |
| # current_passage_buffer = [] | |
| # current_text_buffer = [] | |
| # first_question_started = False | |
| # last_entity_type = None | |
| # just_finished_i_option = False | |
| # is_in_new_passage = False | |
| # def finalize_passage_to_item(item, passage_buffer): | |
| # if passage_buffer: | |
| # passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip() | |
| # if item.get('passage'): item['passage'] += ' ' + passage_text | |
| # else: item['passage'] = passage_text | |
| # passage_buffer.clear() | |
| # # Flatten predictions list if strictly page-separated | |
| # flat_predictions = [] | |
| # for page in predictions: | |
| # flat_predictions.extend(page['data']) | |
| # for idx, item in enumerate(flat_predictions): | |
| # word = item['word'] | |
| # label = item['predicted_label'] | |
| # entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None | |
| # current_text_buffer.append(word) | |
| # previous_entity_type = last_entity_type | |
| # is_passage_label = (entity_type == 'PASSAGE') | |
| # if not first_question_started: | |
| # if label != 'B-QUESTION' and not is_passage_label: | |
| # just_finished_i_option = False | |
| # is_in_new_passage = False | |
| # continue | |
| # if is_passage_label: | |
| # current_passage_buffer.append(word) | |
| # last_entity_type = 'PASSAGE' | |
| # just_finished_i_option = False | |
| # is_in_new_passage = False | |
| # continue | |
| # if label == 'B-QUESTION': | |
| # if not first_question_started: | |
| # header_text = ' '.join(current_text_buffer[:-1]).strip() | |
| # if header_text or current_passage_buffer: | |
| # metadata_item = {'type': 'METADATA', 'passage': ''} | |
| # finalize_passage_to_item(metadata_item, current_passage_buffer) | |
| # if header_text: metadata_item['text'] = header_text | |
| # structured_data.append(metadata_item) | |
| # first_question_started = True | |
| # current_text_buffer = [word] | |
| # if current_item is not None: | |
| # finalize_passage_to_item(current_item, current_passage_buffer) | |
| # current_item['text'] = ' '.join(current_text_buffer[:-1]).strip() | |
| # structured_data.append(current_item) | |
| # current_text_buffer = [word] | |
| # current_item = { | |
| # 'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': '' | |
| # } | |
| # current_option_key = None | |
| # last_entity_type = 'QUESTION' | |
| # just_finished_i_option = False | |
| # is_in_new_passage = False | |
| # continue | |
| # if current_item is not None: | |
| # if is_in_new_passage: | |
| # if 'new_passage' not in current_item: current_item['new_passage'] = word | |
| # else: current_item['new_passage'] += f' {word}' | |
| # if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'): | |
| # is_in_new_passage = False | |
| # if label.startswith(('B-', 'I-')): last_entity_type = entity_type | |
| # continue | |
| # is_in_new_passage = False | |
| # if label.startswith('B-'): | |
| # if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']: | |
| # finalize_passage_to_item(current_item, current_passage_buffer) | |
| # current_passage_buffer = [] | |
| # last_entity_type = entity_type | |
| # if entity_type == 'PASSAGE': | |
| # if previous_entity_type == 'OPTION' and just_finished_i_option: | |
| # current_item['new_passage'] = word | |
| # is_in_new_passage = True | |
| # else: current_passage_buffer.append(word) | |
| # elif entity_type == 'OPTION': | |
| # current_option_key = word | |
| # current_item['options'][current_option_key] = word | |
| # just_finished_i_option = False | |
| # elif entity_type == 'ANSWER': | |
| # current_item['answer'] = word | |
| # current_option_key = None | |
| # just_finished_i_option = False | |
| # elif entity_type == 'QUESTION': | |
| # current_item['question'] += f' {word}' | |
| # just_finished_i_option = False | |
| # elif label.startswith('I-'): | |
| # if entity_type == 'QUESTION': current_item['question'] += f' {word}' | |
| # elif entity_type == 'PASSAGE': | |
| # if previous_entity_type == 'OPTION' and just_finished_i_option: | |
| # current_item['new_passage'] = word | |
| # is_in_new_passage = True | |
| # else: | |
| # if not current_passage_buffer: last_entity_type = 'PASSAGE' | |
| # current_passage_buffer.append(word) | |
| # elif entity_type == 'OPTION' and current_option_key is not None: | |
| # current_item['options'][current_option_key] += f' {word}' | |
| # just_finished_i_option = True | |
| # elif entity_type == 'ANSWER': current_item['answer'] += f' {word}' | |
| # just_finished_i_option = (entity_type == 'OPTION') | |
| # if current_item is not None: | |
| # finalize_passage_to_item(current_item, current_passage_buffer) | |
| # current_item['text'] = ' '.join(current_text_buffer).strip() | |
| # structured_data.append(current_item) | |
| # # Final Cleanup | |
| # for item in structured_data: | |
| # if 'text' in item: item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip() | |
| # if 'new_passage' in item: item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip() | |
| # return structured_data | |
| # # --------------------------------------------------------- | |
| # # 5. INFERENCE PIPELINE | |
| # # --------------------------------------------------------- | |
| # def process_pdf(pdf_file): | |
| # if pdf_file is None: | |
| # return None, "⚠️ Please upload a PDF file." | |
| # try: | |
| # active_model = load_model() | |
| # # A. Extract Text and Boxes | |
| # extracted_pages = [] | |
| # with pdfplumber.open(pdf_file.name) as pdf: | |
| # for page_idx, page in enumerate(pdf.pages): | |
| # width, height = page.width, page.height | |
| # words_data = page.extract_words() | |
| # page_tokens = [] | |
| # page_bboxes = [] | |
| # for w in words_data: | |
| # text = w['text'] | |
| # # Normalize bbox to 0-1000 scale | |
| # x0 = int((w['x0'] / width) * 1000) | |
| # top = int((w['top'] / height) * 1000) | |
| # x1 = int((w['x1'] / width) * 1000) | |
| # bottom = int((w['bottom'] / height) * 1000) | |
| # # Safety clamp | |
| # box = [max(0, min(x0, 1000)), max(0, min(top, 1000)), | |
| # max(0, min(x1, 1000)), max(0, min(bottom, 1000))] | |
| # page_tokens.append(text) | |
| # page_bboxes.append(box) | |
| # extracted_pages.append({"page_id": page_idx, "tokens": page_tokens, "bboxes": page_bboxes}) | |
| # # B. Run Inference | |
| # raw_predictions = [] | |
| # for page in extracted_pages: | |
| # tokens = page['tokens'] | |
| # bboxes = page['bboxes'] | |
| # if not tokens: continue | |
| # # Tokenize | |
| # encoding = tokenizer( | |
| # tokens, | |
| # boxes=bboxes, | |
| # return_tensors="pt", | |
| # padding="max_length", | |
| # truncation=True, | |
| # max_length=512, | |
| # return_offsets_mapping=True | |
| # ) | |
| # input_ids = encoding.input_ids.to(device) | |
| # bbox = encoding.bbox.to(device) | |
| # attention_mask = encoding.attention_mask.to(device) | |
| # # Predict | |
| # with torch.no_grad(): | |
| # # NOTE: If your hybrid model requires 'pixel_values', | |
| # # you will need to add image extraction logic above and pass it here. | |
| # preds = active_model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask) | |
| # # Check if preds returns a tuple (loss, tags) or just tags | |
| # # The CRF implementation usually returns a list of lists of tags in viterbi_decode | |
| # pred_tags = preds[0] if isinstance(preds, tuple) else preds[0] | |
| # # Note: Standard CRF.viterbi_decode returns List[List[int]], so [0] gets the first batch item | |
| # # Alignment | |
| # word_ids = encoding.word_ids() | |
| # aligned_data = [] | |
| # prev_word_idx = None | |
| # for i, word_idx in enumerate(word_ids): | |
| # if word_idx is None: continue | |
| # if word_idx != prev_word_idx: | |
| # # pred_tags is likely a list of ints. | |
| # # If pred_tags[i] fails, your max_length might be cutting off tags, | |
| # # or the model output shape differs from the token length. | |
| # if i < len(pred_tags): | |
| # label_id = pred_tags[i] | |
| # label_str = ID2LABEL.get(label_id, "O") | |
| # aligned_data.append({"word": tokens[word_idx], "predicted_label": label_str}) | |
| # prev_word_idx = word_idx | |
| # raw_predictions.append({"data": aligned_data}) | |
| # # C. Convert to Structured JSON | |
| # final_json = convert_bio_to_structured_json(raw_predictions) | |
| # # Save output | |
| # output_filename = "structured_output.json" | |
| # with open(output_filename, "w", encoding="utf-8") as f: | |
| # json.dump(final_json, f, indent=2, ensure_ascii=False) | |
| # return output_filename, f"✅ Success! Processed {len(extracted_pages)} pages. Extracted {len(final_json)} items." | |
| # except Exception as e: | |
| # import traceback | |
| # return None, f"❌ Error:\n{str(e)}\n\nTraceback:\n{traceback.format_exc()}" | |
| # # --------------------------------------------------------- | |
| # # 6. GRADIO INTERFACE | |
| # # --------------------------------------------------------- | |
| # iface = gr.Interface( | |
| # fn=process_pdf, | |
| # inputs=gr.File(label="Upload PDF", file_types=[".pdf"]), | |
| # outputs=[ | |
| # gr.File(label="Download JSON Output"), | |
| # gr.Textbox(label="Status Log", lines=10) | |
| # ], | |
| # title="Hybrid Model Inference: PDF to JSON", | |
| # description="Upload a document to extract structured data using the custom Hybrid LayoutLMv3 model.", | |
| # flagging_mode="never" | |
| # ) | |
| # if __name__ == "__main__": | |
| # iface.launch() | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import pdfplumber | |
| import json | |
| import os | |
| import re | |
| from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model | |
| from TorchCRF import CRF | |
# ---------------------------------------------------------
# 1. CONFIGURATION
# ---------------------------------------------------------
MODEL_FILENAME = "layoutlmv3_bilstm_crf_hybrid.pth"
BASE_MODEL_ID = "microsoft/layoutlmv3-base"

# Thirteen tags in total: the eleven BIO tags followed by two special tags,
# so the label count matches the 13-row classifier stored in the weights.
# NOTE: list position == label id, and it must match the ids used during
# training. If predictions look shuffled (e.g. questions tagged as options),
# try placing "UNK" and "PAD" first (ids 0 and 1).
LABELS = [
    "O",
    "B-QUESTION", "I-QUESTION",
    "B-OPTION", "I-OPTION",
    "B-ANSWER", "I-ANSWER",
    "B-SECTION_HEADING", "I-SECTION_HEADING",
    "B-PASSAGE", "I-PASSAGE",
    "UNK", "PAD",
]
LABEL2ID = {label: idx for idx, label in enumerate(LABELS)}  # tag string -> id
ID2LABEL = dict(enumerate(LABELS))                           # id -> tag string

# Prefer the GPU when one is visible to torch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
# ---------------------------------------------------------
# 2. MODEL ARCHITECTURE (LayoutLMv3 + BiLSTM + CRF)
# ---------------------------------------------------------
class HybridModel(nn.Module):
    """LayoutLMv3 encoder -> 2-layer BiLSTM -> dropout -> linear emissions -> CRF.

    The submodule attribute names (``layoutlm``, ``lstm``, ``dropout``,
    ``classifier``, ``crf``) are the prefixes of the saved state-dict keys,
    so they must not be renamed.
    """

    def __init__(self, num_labels):
        super().__init__()
        self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)

        # Sizes recovered from the checkpoint: an LSTM input weight of shape
        # [1024, 768] means 4 gates * 256 units per direction, and the
        # presence of 'l1' keys means a second stacked layer.
        per_direction = 256
        self.lstm = nn.LSTM(
            input_size=768,            # LayoutLMv3 output width
            hidden_size=per_direction,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(0.1)

        # Forward and backward states are concatenated: 2 * 256 = 512 inputs,
        # which matches the checkpoint's [13, 512] classifier weight.
        self.classifier = nn.Linear(2 * per_direction, num_labels)
        self.crf = CRF(num_labels)

    def forward(self, input_ids, bbox, attention_mask, labels=None):
        """Run the network on one batch.

        With ``labels``, return the negated mean of the CRF log-likelihood
        (a training/eval loss). Without ``labels``, return the output of
        ``self.crf.viterbi_decode`` (presumably one list of tag ids per
        sequence; confirm against the TorchCRF version in use).
        """
        encoded = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
        recurrent, _ = self.lstm(encoded.last_hidden_state)
        emissions = self.classifier(self.dropout(recurrent))
        crf_mask = attention_mask.bool()

        if labels is None:
            return self.crf.viterbi_decode(emissions, mask=crf_mask)
        return -self.crf(emissions, labels, mask=crf_mask).mean()
# ---------------------------------------------------------
# 3. MODEL LOADING
# ---------------------------------------------------------
model = None


def load_model():
    """Return the shared HybridModel, building and loading it on first use.

    The instance is cached in the module-level ``model`` global, so the
    checkpoint is read from disk only once per process. The global is set
    only after the weights have loaded successfully. Otherwise a failed
    ``load_state_dict`` would leave a half-initialised model (randomly
    initialised BiLSTM/classifier/CRF) cached, and every later call would
    silently return it.

    Returns:
        The HybridModel, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: if MODEL_FILENAME does not exist.
        RuntimeError: if the checkpoint's tensors do not match the
            HybridModel layout or the LABELS count.
    """
    global model
    if model is not None:
        return model

    print(f"🔄 Loading model from {MODEL_FILENAME}...")
    if not os.path.exists(MODEL_FILENAME):
        raise FileNotFoundError(f"❌ Model file '{MODEL_FILENAME}' not found.")

    # Build into a local so a failure below leaves the cache empty.
    candidate = HybridModel(num_labels=len(LABELS))
    state_dict = torch.load(MODEL_FILENAME, map_location=device)
    try:
        candidate.load_state_dict(state_dict)
    except RuntimeError as e:
        # Typically a shape/key mismatch, e.g. a different label count.
        raise RuntimeError(f"❌ Weight mismatch! \nYour model has {len(LABELS)} labels defined in script.\nCheck if 'LABELS' list needs reordering or resizing.\nDetailed Error: {e}") from e

    candidate.to(device)
    candidate.eval()
    model = candidate
    print("✅ Model loaded successfully.")
    return model
# ---------------------------------------------------------
# 4. JSON CONVERSION LOGIC
# ---------------------------------------------------------
def convert_bio_to_structured_json(predictions):
    """Group word-level BIO predictions into question records.

    Args:
        predictions: list of pages, each ``{"data": [{"word": str,
            "predicted_label": str}, ...]}``. Pages are processed in order as
            one continuous word stream.

    Returns:
        A list of dicts:

        * If any words precede the first ``B-QUESTION``, the list starts with
          ``{"type": "METADATA", "passage": ..., "text": ...}``. ``"text"`` is
          present only when non-empty, and ``"passage"`` collects PASSAGE
          words seen so far.
        * Each ``B-QUESTION`` opens a record ``{"question", "options",
          "answer", "passage", "text"}``. ``options`` maps every ``B-OPTION``
          word to that option's joined words, and ``text`` holds all words
          from this question up to the next one.

        Returns ``[]`` if no ``B-QUESTION`` is ever predicted.

    Words tagged ``O``, ``UNK`` or ``PAD`` only contribute to ``text``.
    """
    structured_data = []
    current_item = None
    current_option_key = None
    passage_buffer = []   # PASSAGE words not yet attached to a record
    text_buffer = []      # every word since the current record began

    def flush_passage(item):
        """Append buffered PASSAGE words to item['passage'] and empty the buffer."""
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            if item.get('passage'):
                item['passage'] += ' ' + passage_text
            else:
                item['passage'] = passage_text
            passage_buffer.clear()

    for page in predictions:
        for entry in page['data']:
            word = entry['word']
            label = entry['predicted_label']
            entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
            text_buffer.append(word)

            if label in ("UNK", "PAD", "O"):
                continue

            if entity_type == 'PASSAGE':
                # NOTE(review): buffered passage words are flushed into the
                # *current* record at its next QUESTION/OPTION/ANSWER/
                # SECTION_HEADING start (or at end of input). A passage printed
                # just before a question therefore lands on the preceding
                # question's record. Confirm this is the intended grouping.
                passage_buffer.append(word)
                continue

            if label == 'B-QUESTION':
                if current_item is None:
                    # First question: everything seen so far is header material.
                    header_text = ' '.join(text_buffer[:-1]).strip()
                    if header_text or passage_buffer:
                        metadata_item = {'type': 'METADATA', 'passage': ''}
                        flush_passage(metadata_item)
                        if header_text:
                            metadata_item['text'] = header_text
                        structured_data.append(metadata_item)
                else:
                    flush_passage(current_item)
                    current_item['text'] = ' '.join(text_buffer[:-1]).strip()
                    structured_data.append(current_item)
                text_buffer = [word]
                current_item = {
                    'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''
                }
                current_option_key = None
                continue

            if current_item is None:
                # Before the first question only B-QUESTION and PASSAGE matter.
                continue

            if label.startswith('B-'):
                if entity_type in ('QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING'):
                    flush_passage(current_item)
                if entity_type == 'OPTION':
                    current_option_key = word
                    current_item['options'][current_option_key] = word
                elif entity_type == 'ANSWER':
                    current_item['answer'] = word
                    current_option_key = None
                elif entity_type == 'QUESTION':
                    # Reached only by labels such as 'B-QUESTION ' whose stripped
                    # entity is QUESTION but which miss the exact match above.
                    current_item['question'] += f' {word}'
            elif label.startswith('I-'):
                if entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                elif entity_type == 'OPTION' and current_option_key is not None:
                    current_item['options'][current_option_key] += f' {word}'
                elif entity_type == 'ANSWER':
                    current_item['answer'] += f' {word}'

    if current_item is not None:
        flush_passage(current_item)
        current_item['text'] = ' '.join(text_buffer).strip()
        structured_data.append(current_item)

    # Collapse whitespace runs left by joining words.
    for record in structured_data:
        if 'text' in record:
            record['text'] = re.sub(r'\s{2,}', ' ', record['text']).strip()
    return structured_data
# ---------------------------------------------------------
# 5. PROCESSING PIPELINE
# ---------------------------------------------------------
def _extract_pages(pdf_path):
    """Read each PDF page's words and 0-1000-scaled [x0, top, x1, bottom] boxes.

    Returns a list of ``{"page_id", "tokens", "bboxes"}`` dicts, one per page,
    with every coordinate clamped into [0, 1000].
    """
    def scale(value, extent):
        return max(0, min(int((value / extent) * 1000), 1000))

    pages = []
    with pdfplumber.open(pdf_path) as pdf:
        for page_idx, page in enumerate(pdf.pages):
            width, height = page.width, page.height
            page_tokens, page_bboxes = [], []
            for w in page.extract_words():
                page_tokens.append(w['text'])
                page_bboxes.append([
                    scale(w['x0'], width), scale(w['top'], height),
                    scale(w['x1'], width), scale(w['bottom'], height),
                ])
            pages.append({"page_id": page_idx, "tokens": page_tokens, "bboxes": page_bboxes})
    return pages


def _label_page_words(active_model, tokens, bboxes, max_length=512):
    """Predict a label for every word on a page, in page order.

    One tokenizer call truncates at ``max_length`` sub-word tokens, which
    would silently drop the rest of a dense page. Instead the page is encoded
    in successive windows. Each window starts at the first word the previous
    one did not reach, and each word takes the tag predicted for its first
    sub-word token.

    Returns a list of ``{"word", "predicted_label"}`` dicts.
    """
    aligned_data = []
    start = 0
    while start < len(tokens):
        # A window can never hold more than max_length words (non-empty words
        # should each yield at least one sub-word). Slicing avoids re-encoding
        # the whole remaining page on every pass.
        window_tokens = tokens[start:start + max_length]
        window_bboxes = bboxes[start:start + max_length]
        encoding = tokenizer(
            window_tokens,
            boxes=window_bboxes,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )
        with torch.no_grad():
            batch_paths = active_model(
                input_ids=encoding.input_ids.to(device),
                bbox=encoding.bbox.to(device),
                attention_mask=encoding.attention_mask.to(device),
            )
        pred_tags = batch_paths[0]  # batch of one -> first tag sequence

        last_word_idx = None
        for i, word_idx in enumerate(encoding.word_ids()):
            if word_idx is None or word_idx == last_word_idx:
                continue  # special/padding token, or a continuation sub-word
            last_word_idx = word_idx
            if i < len(pred_tags):
                label_str = ID2LABEL.get(pred_tags[i], "O")
                aligned_data.append({"word": window_tokens[word_idx], "predicted_label": label_str})

        if last_word_idx is None:
            break  # no word in this window produced tokens; cannot advance
        start += last_word_idx + 1
    return aligned_data


def process_pdf(pdf_file):
    """Gradio handler: turn an uploaded PDF into a downloadable structured JSON.

    Args:
        pdf_file: the ``gr.File`` value. Either an object exposing the file
            path as ``.name`` or a plain path string. ``None`` when nothing
            was uploaded.

    Returns:
        ``(output_path, status_message)``. ``output_path`` is ``None`` on
        failure, and the message then includes the error and its traceback.
    """
    if pdf_file is None:
        return None, "⚠️ Please upload a PDF file."
    try:
        active_model = load_model()
        pdf_path = getattr(pdf_file, "name", pdf_file)
        extracted_pages = _extract_pages(pdf_path)

        raw_predictions = [
            {"data": _label_page_words(active_model, page['tokens'], page['bboxes'])}
            for page in extracted_pages
            if page['tokens']  # pages without extractable words are skipped
        ]
        final_json = convert_bio_to_structured_json(raw_predictions)

        output_filename = "structured_output.json"
        with open(output_filename, "w", encoding="utf-8") as f:
            json.dump(final_json, f, indent=2, ensure_ascii=False)
        return output_filename, f"✅ Success! Processed {len(extracted_pages)} pages. Extracted {len(final_json)} items."
    except Exception as e:
        # UI boundary: report any failure in the status box instead of crashing.
        import traceback
        return None, f"❌ Error:\n{str(e)}\n\nTraceback:\n{traceback.format_exc()}"
# ---------------------------------------------------------
# 6. GRADIO INTERFACE
# ---------------------------------------------------------
# One PDF in; a downloadable JSON file plus a status message out. The output
# order matches the (path, message) tuple that process_pdf returns.
_pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
_result_outputs = [
    gr.File(label="Download JSON Output"),
    gr.Textbox(label="Status Log", lines=10),
]

iface = gr.Interface(
    fn=process_pdf,
    inputs=_pdf_input,
    outputs=_result_outputs,
    title="LayoutLMv3 + BiLSTM Hybrid Model Inference",
    description="Upload a document to extract structured data using the custom Hybrid LayoutLMv3 model.",
    flagging_mode="never",
)

if __name__ == "__main__":
    iface.launch()