import os
import json
import pickle
import re
from typing import List, Dict, Any, Tuple
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# === GRADIO AND DEPENDENCIES ===
import gradio as gr
import fitz  # PyMuPDF
from PIL import Image, ImageEnhance
import pytesseract

try:
    # Attempt to import the actual CRF layer for correct Viterbi decoding
    from TorchCRF import CRF
except ImportError:
    # Placeholder for environments where it's not yet installed, enabling model definition
    class CRF:
        def __init__(self, *args, **kwargs):
            pass

        # Fallback to simple argmax decoding if the CRF module is missing
        def viterbi_decode(self, emissions, mask):
            return [list(torch.argmax(emissions[0], dim=-1).cpu().numpy())]

# ========== CONFIG (Must match Training Script) ==========
MODEL_FILE = "model_CAT.pt"
VOCAB_FILE = "vocabs_CAT.pkl"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_CHAR_LEN = 16
EMBED_DIM = 100
CHAR_EMBED_DIM = 30
CHAR_CNN_OUT = 30
BBOX_DIM = 100
HIDDEN_SIZE = 512
BBOX_NORM_CONSTANT = 1000.0
INFERENCE_CHUNK_SIZE = 256

# ========== LABELS (Must match Training Script) ==========
# Variant including PASSAGE tags, kept for the passage-aware structuring logic:
# LABELS = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION",
#           "B-ANSWER", "I-ANSWER", "B-IMAGE", "I-IMAGE", "B-PASSAGE", "I-PASSAGE"]
LABELS = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION",
          "B-ANSWER", "I-ANSWER", "B-IMAGE", "I-IMAGE"]
LABEL2IDX = {l: i for i, l in enumerate(LABELS)}
IDX2LABEL = {i: l for i, l in enumerate(LABELS)}

# =========================================================
# 1. Core Classes (Vocab, CharCNNEncoder, MCQTagger)
# =========================================================
class Vocab:
    def __init__(self, min_freq=1, unk_token="<unk>", pad_token="<pad>"):
        self.min_freq = min_freq
        self.unk_token = unk_token
        self.pad_token = pad_token
        self.freq = Counter()
        self.itos = []
        self.stoi = {}

    def add_sentence(self, toks):
        self.freq.update(toks)

    def build(self):
        items = [tok for tok, c in self.freq.items() if c >= self.min_freq]
        items = [self.pad_token, self.unk_token] + sorted(items)
        self.itos = items
        self.stoi = {s: i for i, s in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)

    def __getitem__(self, token: str) -> int:
        return self.stoi.get(token, self.stoi[self.unk_token])

    def __getstate__(self):
        return {
            'min_freq': self.min_freq,
            'unk_token': self.unk_token,
            'pad_token': self.pad_token,
            'itos': self.itos,
            'stoi': self.stoi,
        }

    def __setstate__(self, state):
        self.min_freq = state['min_freq']
        self.unk_token = state['unk_token']
        self.pad_token = state['pad_token']
        self.itos = state['itos']
        self.stoi = state['stoi']
        self.freq = Counter()


def load_vocabs(path: str) -> Tuple[Vocab, Vocab]:
    """Loads word and character vocabularies."""
    try:
        absolute_path = os.path.abspath(path)
        with open(absolute_path, "rb") as f:
            word_vocab, char_vocab = pickle.load(f)
        if len(word_vocab) <= 2:
            raise IndexError("CRITICAL: Word vocabulary size is too small.")
        return word_vocab, char_vocab
    except Exception as e:
        raise RuntimeError(f"Error loading vocabs from {path}: {e}")


class CharCNNEncoder(nn.Module):
    def __init__(self, char_vocab_size, char_emb_dim, out_dim, kernel_sizes=(3, 4, 5)):
        super().__init__()
        self.char_emb = nn.Embedding(char_vocab_size, char_emb_dim, padding_idx=0)
        convs = [nn.Conv1d(char_emb_dim, out_dim, kernel_size=k) for k in kernel_sizes]
        self.convs = nn.ModuleList(convs)
        self.out_dim = out_dim * len(convs)

    def forward(self, char_ids):
        B, L, C = char_ids.size()
        emb = self.char_emb(char_ids.view(B * L, C)).transpose(1, 2)
        # Max-pool over the character dimension for each kernel size
        outs = [torch.max(torch.relu(conv(emb)), dim=2)[0] for conv in self.convs]
        res = torch.cat(outs, dim=1)
        return res.view(B, L, -1)


class MCQTagger(nn.Module):
    def __init__(self, vocab_size, char_vocab_size, n_labels, bbox_dim=BBOX_DIM):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=0)
        self.char_enc = CharCNNEncoder(char_vocab_size, CHAR_EMBED_DIM, CHAR_CNN_OUT)
        self.bbox_proj = nn.Linear(4, bbox_dim)
        in_dim = EMBED_DIM + self.char_enc.out_dim + bbox_dim
        self.bilstm = nn.LSTM(in_dim, HIDDEN_SIZE // 2, num_layers=2,
                              batch_first=True, bidirectional=True, dropout=0.3)
        self.ff = nn.Linear(HIDDEN_SIZE, n_labels)
        self.crf = CRF(n_labels)
        self.dropout = nn.Dropout(p=0.5)

    def forward_emissions(self, words, chars, bboxes, mask):
        wemb = self.word_emb(words)
        cenc = self.char_enc(chars)
        benc = self.bbox_proj(bboxes)
        enc_in = torch.cat([wemb, cenc, benc], dim=-1)
        enc_in = self.dropout(enc_in)
        lengths = mask.sum(dim=1).cpu()
        if lengths.max().item() == 0:
            B, L = enc_in.size(0), enc_in.size(1)
            # Return zero tensor if batch is empty
            return torch.zeros((B, L, len(LABELS)), device=enc_in.device)
        packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths,
                                                      batch_first=True, enforce_sorted=False)
        packed_out, _ = self.bilstm(packed_in)
        padded_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        return self.ff(padded_out)

    def forward(self, words, chars, bboxes, mask, labels=None, class_weights=None, alpha=0.7):
        # Inference-only path: labels/class_weights/alpha are kept for signature
        # compatibility with the training script but are unused here.
        emissions = self.forward_emissions(words, chars, bboxes, mask)
        return self.crf.viterbi_decode(emissions, mask=mask)
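
# --- Optional shape sanity check (illustrative only, not part of the app) ---
# A minimal sketch that pushes random tensors through MCQTagger to confirm the
# expected emission shape; the vocabulary sizes here are arbitrary placeholders.
def _smoke_test_model():
    model = MCQTagger(vocab_size=100, char_vocab_size=50, n_labels=len(LABELS))
    B, L = 1, 8
    words = torch.randint(1, 100, (B, L))
    chars = torch.randint(1, 50, (B, L, MAX_CHAR_LEN))
    bboxes = torch.rand(B, L, 4)  # already scaled to [0, 1]
    mask = torch.ones(B, L, dtype=torch.bool)
    with torch.no_grad():
        emissions = model.forward_emissions(words, chars, bboxes, mask)
    assert emissions.shape == (B, L, len(LABELS))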

# =========================================================
# 2. PDF Processing Functions
# =========================================================
def ocr_fallback_page(page: fitz.Page, page_width: float, page_height: float) -> List[Dict[str, Any]]:
    """Renders a PyMuPDF page, runs Tesseract OCR, and tokenizes the result."""
    try:
        pix = page.get_pixmap(matrix=fitz.Matrix(3, 3))
        if pix.n - pix.alpha > 3:
            pix = fitz.Pixmap(fitz.csRGB, pix)
        img_pil = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

        # Preprocessing for Tesseract
        img_pil = img_pil.convert('L')
        img_pil = ImageEnhance.Contrast(img_pil).enhance(2.0)
        img_pil = ImageEnhance.Sharpness(img_pil).enhance(2.0)

        ocr_data = pytesseract.image_to_data(img_pil, output_type=pytesseract.Output.DICT)
        ocr_tokens = []
        for i in range(len(ocr_data['text'])):
            word = ocr_data['text'][i]
            conf = ocr_data['conf'][i]
            if word.strip() and int(conf) > 50:
                left, top, width, height = (ocr_data[k][i] for k in ['left', 'top', 'width', 'height'])
                # Map pixel coordinates back to PDF points. The render matrix is
                # uniform, but x and y are scaled separately for correctness.
                scale_x = page_width / pix.width
                scale_y = page_height / pix.height
                raw_bbox = [
                    left * scale_x,
                    top * scale_y,
                    (left + width) * scale_x,
                    (top + height) * scale_y
                ]
                normalized_bbox = [
                    (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
                    (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,
                    (raw_bbox[2] / page_width) * BBOX_NORM_CONSTANT,
                    (raw_bbox[3] / page_height) * BBOX_NORM_CONSTANT
                ]
                ocr_tokens.append({
                    "word": word,
                    "raw_bbox": [int(b) for b in raw_bbox],
                    "normalized_bbox": [int(b) for b in normalized_bbox]
                })
        return ocr_tokens
    except Exception as e:
        print(f"OCR fallback failed: {e}")
        return []


def extract_tokens_from_pdf_fitz_with_ocr(pdf_path: str) -> List[Dict[str, Any]]:
    """Extracts words and bboxes using PyMuPDF's text layer, falling back to OCR."""
    all_tokens = []
    try:
        doc = fitz.open(pdf_path)
        for page_num in tqdm(range(len(doc)), desc="PDF Page Processing"):
            page = doc.load_page(page_num)
            page_width, page_height = page.rect.width, page.rect.height
            page_tokens = []

            # 1. Primary Extraction: PyMuPDF's word structure
            word_list = page.get_text("words", sort=True)
            if word_list:
                for word_data in word_list:
                    word = word_data[4]
                    raw_bbox = word_data[:4]
                    normalized_bbox = [
                        (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
                        (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,
                        (raw_bbox[2] / page_width) * BBOX_NORM_CONSTANT,
                        (raw_bbox[3] / page_height) * BBOX_NORM_CONSTANT
                    ]
                    page_tokens.append({
                        "word": word,
                        "raw_bbox": [int(b) for b in raw_bbox],
                        "normalized_bbox": [int(b) for b in normalized_bbox]
                    })

            # 2. OCR Fallback
            if not page_tokens:
                print(f" (Page {page_num + 1}) No text layer found. Running OCR...")
                page_tokens = ocr_fallback_page(page, page_width, page_height)

            all_tokens.extend(page_tokens)
        doc.close()
    except Exception as e:
        raise RuntimeError(f"Error opening or processing PDF with fitz/OCR: {e}")
    return all_tokens
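
# Worked example of the bbox normalization above (numbers are illustrative):
# on a 612 x 792 pt page, a raw bbox of [72, 144, 288, 216] becomes
#   x0 = 72 / 612 * 1000  -> 117      y0 = 144 / 792 * 1000 -> 181
#   x1 = 288 / 612 * 1000 -> 470      y1 = 216 / 792 * 1000 -> 272
# i.e. normalized_bbox = [117, 181, 470, 272]; preprocess_and_collate_tokens
# below divides by BBOX_NORM_CONSTANT again so the model sees values in [0, 1].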
Running OCR...") page_tokens = ocr_fallback_page(page, page_width, page_height) all_tokens.extend(page_tokens) doc.close() except Exception as e: raise RuntimeError(f"Error opening or processing PDF with fitz/OCR: {e}") return all_tokens extract_tokens_from_pdf = extract_tokens_from_pdf_fitz_with_ocr def preprocess_and_collate_tokens(all_tokens: List[Dict[str, Any]], word_vocab: Vocab, char_vocab: Vocab, chunk_size: int) -> List[Dict[str, Any]]: """Chunks the token list, converts to IDs, and prepares batches for inference.""" all_batches = [] for i in range(0, len(all_tokens), chunk_size): chunk = all_tokens[i:i + chunk_size] if not chunk: continue words = [t["word"] for t in chunk] bboxes_norm = [t["normalized_bbox"] for t in chunk] # Convert to IDs word_ids = [word_vocab[w] for w in words] char_ids = [] for w in words: chs = [char_vocab[ch] for ch in w[:MAX_CHAR_LEN]] if len(chs) < MAX_CHAR_LEN: pad_index = char_vocab.stoi.get(char_vocab.pad_token, 0) chs += [pad_index] * (MAX_CHAR_LEN - len(chs)) char_ids.append(chs) # Create padded tensors (using single-sample batches) word_pad = torch.LongTensor([word_ids]).to(DEVICE) char_pad = torch.LongTensor([char_ids]).to(DEVICE) # Final normalization to [0, 1] range before feeding to the model bbox_pad = torch.FloatTensor([bboxes_norm]).to(DEVICE) / BBOX_NORM_CONSTANT mask = torch.ones(word_pad.size(), dtype=torch.bool).to(DEVICE) all_batches.append({ "words": word_pad, "chars": char_pad, "bboxes": bbox_pad, "mask": mask, "original_tokens": chunk }) return all_batches # ========================================================= # 3. Model Loading and Caching (Global Variables Defined Here!) # ========================================================= # Global variables (MODEL, VOCABS) are defined here for use in the wrapper function WORD_VOCAB = None CHAR_VOCAB = None MODEL = None try: WORD_VOCAB, CHAR_VOCAB = load_vocabs(VOCAB_FILE) MODEL = MCQTagger(len(WORD_VOCAB), len(CHAR_VOCAB), len(LABELS)).to(DEVICE) MODEL.load_state_dict(torch.load(MODEL_FILE, map_location=DEVICE)) MODEL.eval() print("✅ Model and Vocabs loaded successfully (Cached).") except Exception as e: # This prevents the app from crashing if the model files are missing on startup print(f"❌ Initial Model/Vocab Load Failure: {e}") print("The Gradio demo will not function until model_CAT.pt and vocabs_CAT.pkl are found.") # ========================================================= # 4. Structuring Logic (Converts BIO to clean JSON) # ========================================================= def finalize_passage_to_item(item, passage_buffer): """Adds passage text to the current item and clears the buffer.""" if passage_buffer: passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip() if item.get('passage'): item['passage'] += ' ' + passage_text else: item['passage'] = passage_text passage_buffer.clear() return item def convert_bio_to_structured_json_strict(predictions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Converts a list of {word, predicted_label} tokens into structured MCQ JSON format. 
""" structured_data = [] current_item = None current_option_key = None current_passage_buffer = [] current_text_buffer = [] first_question_started = False last_entity_type = None for item in predictions: word = item['word'] label = item['predicted_label'] entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None current_text_buffer.append(word) is_passage_label = (label == 'B-PASSAGE' or label == 'I-PASSAGE') # --- BEFORE FIRST QUESTION/METADATA HANDLING --- if not first_question_started and label != 'B-QUESTION' and not is_passage_label: continue # --- PASSAGE HANDLING (Before question start) --- if not first_question_started and is_passage_label: if label == 'B-PASSAGE' or (label == 'I-PASSAGE' and last_entity_type == 'PASSAGE'): current_passage_buffer.append(word) last_entity_type = 'PASSAGE' continue # --- NEW QUESTION START (B-QUESTION) --- if label == 'B-QUESTION': # 1. Capture leading text/passage as METADATA if not first_question_started: header_text = ' '.join(current_text_buffer[:-1]).strip() if header_text or current_passage_buffer: metadata_item = {'type': 'METADATA'} metadata_item = finalize_passage_to_item(metadata_item, current_passage_buffer) if header_text: metadata_item['text'] = header_text structured_data.append(metadata_item) first_question_started = True current_text_buffer = [word] # 2. Save previous question block elif current_item is not None: current_item = finalize_passage_to_item(current_item, current_passage_buffer) current_item['text'] = ' '.join(current_text_buffer[:-1]).strip() structured_data.append(current_item) current_text_buffer = [word] # 3. Initialize new question current_item = { 'type': 'MCQ', 'question': word, 'options': {}, 'answer': '', 'text': '' } current_option_key = None last_entity_type = 'QUESTION' continue # --- IF INSIDE A QUESTION BLOCK --- if current_item is not None: if label.startswith('B-'): last_entity_type = entity_type if entity_type == 'PASSAGE': finalize_passage_to_item(current_item, current_passage_buffer) current_passage_buffer.append(word) elif entity_type == 'OPTION': current_option_key = word current_item['options'][current_option_key] = word current_passage_buffer = [] elif entity_type == 'ANSWER': current_item['answer'] = word current_option_key = None current_passage_buffer = [] elif entity_type == 'QUESTION': current_item['question'] += f' {word}' current_passage_buffer = [] elif label.startswith('I-'): if entity_type == 'QUESTION' and last_entity_type == 'QUESTION': current_item['question'] += f' {word}' elif entity_type == 'OPTION' and last_entity_type == 'OPTION' and current_option_key is not None: current_item['options'][current_option_key] += f' {word}' elif entity_type == 'ANSWER' and last_entity_type == 'ANSWER': current_item['answer'] += f' {word}' elif entity_type == 'PASSAGE' and last_entity_type == 'PASSAGE': current_passage_buffer.append(word) elif label == 'O': pass # --- Finalize last item --- if current_item is not None: current_item = finalize_passage_to_item(current_item, current_passage_buffer) current_item['text'] = re.sub(r'\s{2,}', ' ', ' '.join(current_text_buffer)).strip() structured_data.append(current_item) elif not structured_data and current_passage_buffer: # Case: Only passage/metadata was present in the whole document metadata_item = {'type': 'METADATA'} metadata_item = finalize_passage_to_item(metadata_item, current_passage_buffer) metadata_item['text'] = re.sub(r'\s{2,}', ' ', ' '.join(current_text_buffer)).strip() structured_data.append(metadata_item) # --- FINAL CLEANUP 
--- for item in structured_data: # Clean up all text fields for excessive whitespace item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip() if 'passage' in item: item['passage'] = re.sub(r'\s{2,}', ' ', item['passage']).strip() if not item['passage']: del item['passage'] for field in ['question', 'answer']: if field in item: item[field] = re.sub(r'\s{2,}', ' ', item[field]).strip() if 'options' in item: for k, v in item['options'].items(): item['options'][k] = re.sub(r'\s{2,}', ' ', v).strip() return structured_data # ========================================================= # 5. The Gradio Inference Wrapper Function (Main Entry Point) # ========================================================= def gradio_inference_wrapper(pdf_file: str) -> Tuple[str, List[Dict[str, Any]]]: """ Wraps the entire two-stage pipeline: (1) Tagging -> (2) Structuring. """ # Uses global variables defined in Section 3 if MODEL is None: return "❌ ERROR: Model failed to load on startup. Check 'model_CAT.pt' and 'vocabs_CAT.pkl'.", [] pdf_path = pdf_file raw_predictions = [] try: # 1. Stage 1: PDF Processing and BIO Tagging all_tokens = extract_tokens_from_pdf(pdf_path) if not all_tokens: return "❌ ERROR: No tokens were extracted from the PDF, even after OCR fallback.", [] # Uses global variables WORD_VOCAB, CHAR_VOCAB, INFERENCE_CHUNK_SIZE batches = preprocess_and_collate_tokens(all_tokens, WORD_VOCAB, CHAR_VOCAB, chunk_size=INFERENCE_CHUNK_SIZE) with torch.no_grad(): for batch in batches: words, chars, bboxes, mask = (batch[k] for k in ["words", "chars", "bboxes", "mask"]) preds_batch = MODEL(words, chars, bboxes, mask) predictions = preds_batch[0] original_tokens = batch["original_tokens"] for token_data, pred_idx in zip(original_tokens, predictions): # Uses global variable IDX2LABEL raw_predictions.append({ "word": token_data["word"], "bbox": token_data["raw_bbox"], "predicted_label": IDX2LABEL[pred_idx] }) # 2. Stage 2: Structured JSON Conversion structured_output = convert_bio_to_structured_json_strict(raw_predictions) mcq_count = len([i for i in structured_output if i.get('type') == 'MCQ']) status_message = f"✅ Conversion complete. Found {mcq_count} MCQ items and {len(structured_output) - mcq_count} Metadata blocks." return status_message, structured_output except RuntimeError as e: return f"❌ PDF Processing Error: {e}", [] except Exception as e: return f"❌ An unexpected processing error occurred: {e}", [] # ========================================================= # 6. Define and Launch the Gradio Interface # ========================================================= if __name__ == "__main__": title = "MCQ Document Structure Tagger (Bi-LSTM-CRF) - Structured Output" description = "Upload a PDF document. The system processes it in two stages: 1) BIO-Tagging for structural elements (Question, Option, Answer, Passage) and 2) Converting those tags into a clean, structured JSON list of MCQ items." demo = gr.Interface( fn=gradio_inference_wrapper, # Ensure only PDF files are accepted inputs=gr.File(label="Upload PDF Document"), outputs=[ gr.Textbox(label="Status Message", interactive=False), gr.JSON(label="Structured MCQ JSON Output", show_label=True) ], title=title, description=description, allow_flagging="never", concurrency_limit=2 ) demo.launch(show_error=True)
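
# Programmatic use without the UI (illustrative; "sample.pdf" is a placeholder path):
#   status, items = gradio_inference_wrapper("sample.pdf")
#   print(status)
#   print(json.dumps(items, indent=2, ensure_ascii=False))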