# edugenius / app.py
# Last commit 86036b1 ("refactor") by aagamjtdev.
import os
import json
import pickle
from typing import List, Dict, Any, Tuple
from collections import Counter
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
from tqdm import tqdm
# === GRADIO AND DEPENDENCIES ===
import gradio as gr
import fitz # PyMuPDF
from PIL import Image, ImageEnhance
import pytesseract
try:
    # Prefer the real CRF layer so decoding uses proper Viterbi search.
    from TorchCRF import CRF
except ImportError:
    print("WARNING: TorchCRF is not installed; falling back to per-token argmax decoding (no CRF transitions).")

    class CRF:
        """Parameter-free stand-in for ``TorchCRF.CRF`` used when the package is missing.

        It lets ``MCQTagger`` be constructed, and it decodes by taking the
        highest-scoring label at each position independently. No transition
        scores are used.

        NOTE(review): a checkpoint trained with TorchCRF presumably contains CRF
        parameters that this placeholder cannot hold, so a strict
        ``load_state_dict`` will likely reject them. Confirm before relying on
        this fallback.
        """

        def __init__(self, *args, **kwargs):
            # Accepts (and ignores) the real constructor's arguments.
            pass

        def viterbi_decode(self, emissions, mask=None):
            """Return one list of label indices (Python ints) per batch row.

            Args:
                emissions: Score tensor; argmax is taken over the last dim.
                    Presumably (batch, seq_len, n_labels). TODO confirm.
                mask: Optional (batch, seq_len) boolean tensor. Each row's path is
                    cut to ``mask[row].sum()`` positions. This assumes the valid
                    positions form a prefix, as with right-padded batches.
            """
            best = torch.argmax(emissions, dim=-1).tolist()
            if mask is None:
                return best
            lengths = mask.sum(dim=1).tolist()
            return [row[:int(n)] for row, n in zip(best, lengths)]
# ========== CONFIG (Must match Training Script) ==========
# Trained artifacts (relative paths, looked up from the working directory).
MODEL_FILE: str = "model_CAT.pt"
VOCAB_FILE: str = "vocabs_CAT.pkl"

# Run on CUDA when it is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Architecture hyper-parameters. They must equal the training values, otherwise
# the checkpoint's tensor shapes will not match the freshly built model.
MAX_CHAR_LEN: int = 16      # characters kept per token for the char-CNN
EMBED_DIM: int = 100        # word-embedding width
CHAR_EMBED_DIM: int = 30    # character-embedding width
CHAR_CNN_OUT: int = 30      # filters per char-CNN kernel size
BBOX_DIM: int = 100         # width of the projected bounding-box feature
HIDDEN_SIZE: int = 512      # BiLSTM output width (both directions combined)

# Bounding boxes are stored on a 0..1000 grid; dividing by this gives [0, 1].
BBOX_NORM_CONSTANT: float = 1000.0

# Number of tokens fed to the model per inference chunk.
INFERENCE_CHUNK_SIZE: int = 256
# ========== LABELS (Must match Training Script) ==========
# BIO tag set the checkpoint was trained with. List order defines the class
# indices, so it must match the training script exactly.
# NOTE: PASSAGE tags are NOT part of this set. The structuring code still
# recognises "B-PASSAGE"/"I-PASSAGE", but this model never emits them. A
# PASSAGE-aware checkpoint would append "B-PASSAGE", "I-PASSAGE" to the list.
LABELS = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION", "B-ANSWER", "I-ANSWER", "B-IMAGE", "I-IMAGE"]
LABEL2IDX = {label: idx for idx, label in enumerate(LABELS)}
IDX2LABEL = dict(enumerate(LABELS))
# =========================================================
# 1. Core Classes (Vocab, CharCNNEncoder, MCQTagger)
# =========================================================
class Vocab:
    """Token <-> index mapping with reserved padding and unknown entries.

    After ``build()``, index 0 is ``pad_token``, index 1 is ``unk_token``, and the
    remaining tokens follow in sorted order. Pickling keeps only the settings
    and ``itos``/``stoi``; the raw frequency counts are dropped.
    """

    def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
        self.min_freq = min_freq
        self.unk_token = unk_token
        self.pad_token = pad_token
        self.freq = Counter()
        self.itos = []
        self.stoi = {}

    def add_sentence(self, toks):
        """Count every token in *toks*."""
        self.freq.update(toks)

    def build(self):
        """(Re)build the index from tokens seen at least ``min_freq`` times."""
        specials = (self.pad_token, self.unk_token)
        # Specials get fixed slots 0 and 1. Excluding them from the counted
        # tokens prevents a duplicate entry from moving their index in ``stoi``.
        items = sorted(
            tok for tok, c in self.freq.items()
            if c >= self.min_freq and tok not in specials
        )
        self.itos = [*specials, *items]
        self.stoi = {s: i for i, s in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)

    def __getitem__(self, token: str) -> int:
        """Return the index of *token*, or the ``unk_token`` index if unseen."""
        idx = self.stoi.get(token)
        # Only look up UNK when actually needed.
        return self.stoi[self.unk_token] if idx is None else idx

    def __getstate__(self):
        # Persist the mapping only; ``freq`` can be large and is not needed for inference.
        return {
            'min_freq': self.min_freq,
            'unk_token': self.unk_token,
            'pad_token': self.pad_token,
            'itos': self.itos,
            'stoi': self.stoi,
        }

    def __setstate__(self, state):
        self.min_freq = state['min_freq']
        self.unk_token = state['unk_token']
        self.pad_token = state['pad_token']
        self.itos = state['itos']
        self.stoi = state['stoi']
        self.freq = Counter()
def load_vocabs(path: str) -> Tuple[Vocab, Vocab]:
    """Load the ``(word_vocab, char_vocab)`` pair pickled at *path*.

    NOTE(review): the pickle presumably references the training script's
    ``Vocab`` class, likely as ``__main__.Vocab``. This module must therefore
    define a compatible ``Vocab``. TODO confirm.

    Raises:
        RuntimeError: if the file cannot be opened or unpickled, does not hold a
            2-item pair, or the word vocabulary has no entries beyond the two
            reserved PAD/UNK slots. The original error is chained as ``__cause__``.
    """
    absolute_path = os.path.abspath(path)
    try:
        # SECURITY: pickle.load can execute arbitrary code; only load trusted vocab files.
        with open(absolute_path, "rb") as f:
            word_vocab, char_vocab = pickle.load(f)
        if len(word_vocab) <= 2:
            raise ValueError("CRITICAL: Word vocabulary size is too small.")
    except Exception as e:
        raise RuntimeError(f"Error loading vocabs from {path}: {e}") from e
    return word_vocab, char_vocab
class CharCNNEncoder(nn.Module):
    """Character-level CNN that turns each token's character ids into one vector.

    One ``Conv1d`` per kernel size runs over the character axis. Each output is
    passed through ReLU and max-pooled over that axis, and the pooled features
    of all kernels are concatenated. The attribute ``self.out_dim`` is therefore
    ``out_dim * len(kernel_sizes)``.
    """

    def __init__(self, char_vocab_size, char_emb_dim, out_dim, kernel_sizes=(3, 4, 5)):
        super().__init__()
        # Attribute names (char_emb, convs) are state_dict keys of saved checkpoints.
        self.char_emb = nn.Embedding(char_vocab_size, char_emb_dim, padding_idx=0)
        self.convs = nn.ModuleList(
            nn.Conv1d(char_emb_dim, out_dim, kernel_size=size) for size in kernel_sizes
        )
        self.out_dim = out_dim * len(self.convs)

    def forward(self, char_ids):
        """Map a (batch, tokens, chars) id tensor to (batch, tokens, self.out_dim)."""
        batch, n_tok, n_chr = char_ids.size()
        flat_ids = char_ids.view(batch * n_tok, n_chr)
        # Conv1d expects channels first: (batch*tokens, emb_dim, chars).
        embedded = self.char_emb(flat_ids).transpose(1, 2)
        pooled = []
        for conv in self.convs:
            pooled.append(torch.relu(conv(embedded)).max(dim=2).values)
        return torch.cat(pooled, dim=1).view(batch, n_tok, -1)
class MCQTagger(nn.Module):
    """BiLSTM-CRF sequence tagger over word, character and bounding-box features.

    Per-token features are concatenated: word embedding, char-CNN encoding, and
    a linear projection of the 4-value bounding box. A 2-layer bidirectional
    LSTM and a linear layer then produce per-label emission scores, which
    ``self.crf.viterbi_decode`` turns into label indices.
    """

    def __init__(self, vocab_size, char_vocab_size, n_labels, bbox_dim=BBOX_DIM):
        super().__init__()
        # Plain int attribute: not part of the state_dict, used for shape bookkeeping.
        self.n_labels = n_labels
        # NOTE: the module attribute names below are the checkpoint's state_dict keys.
        self.word_emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=0)
        self.char_enc = CharCNNEncoder(char_vocab_size, CHAR_EMBED_DIM, CHAR_CNN_OUT)
        self.bbox_proj = nn.Linear(4, bbox_dim)
        in_dim = EMBED_DIM + self.char_enc.out_dim + bbox_dim
        # HIDDEN_SIZE // 2 per direction, so the BiLSTM output width is HIDDEN_SIZE.
        self.bilstm = nn.LSTM(in_dim, HIDDEN_SIZE // 2, num_layers=2, batch_first=True, bidirectional=True, dropout=0.3)
        self.ff = nn.Linear(HIDDEN_SIZE, n_labels)
        self.crf = CRF(n_labels)
        # Only active in train mode; the app calls .eval() before inference.
        self.dropout = nn.Dropout(p=0.5)

    def forward_emissions(self, words, chars, bboxes, mask):
        """Return emission scores of shape (batch, seq_len, n_labels).

        ``seq_len`` is ``words.size(1)``, so the result always lines up with
        *mask*, even when every row is shorter than the padded width.
        """
        wemb = self.word_emb(words)
        cenc = self.char_enc(chars)
        benc = self.bbox_proj(bboxes)
        enc_in = self.dropout(torch.cat([wemb, cenc, benc], dim=-1))
        batch_size, seq_len = enc_in.size(0), enc_in.size(1)
        # pack_padded_sequence requires the lengths on the CPU.
        lengths = mask.sum(dim=1).cpu()
        if lengths.numel() == 0 or lengths.max().item() == 0:
            # Nothing to encode: return all-zero scores of the expected shape.
            return torch.zeros((batch_size, seq_len, self.n_labels), device=enc_in.device)
        packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.bilstm(packed_in)
        # total_length pads back to the full width instead of the longest row,
        # so the emissions' time axis matches the mask given to the CRF.
        padded_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=seq_len)
        return self.ff(padded_out)

    def forward(self, words, chars, bboxes, mask, labels=None, class_weights=None, alpha=0.7):
        """Decode the best label-index sequence for each batch row.

        ``labels``, ``class_weights`` and ``alpha`` are accepted for signature
        compatibility but are not used by this inference-only forward.
        """
        emissions = self.forward_emissions(words, chars, bboxes, mask)
        return self.crf.viterbi_decode(emissions, mask=mask)
# =========================================================
# 2. PDF Processing Functions
# =========================================================
def ocr_fallback_page(page: fitz.Page, page_width: float, page_height: float) -> List[Dict[str, Any]]:
    """Render *page*, OCR it with Tesseract, and return token dicts.

    Each token dict has ``word``, ``raw_bbox`` (page coordinates, truncated to
    int) and ``normalized_bbox`` (0..BBOX_NORM_CONSTANT grid, truncated to int).
    Empty words and words with Tesseract confidence <= 50 are dropped.

    This is best-effort: any failure (e.g. Tesseract not installed, render
    error) is printed and an empty list is returned, so one bad page does not
    abort the whole document.
    """
    try:
        # Render at 3x zoom for better OCR accuracy. Explicit RGB without alpha
        # guarantees the samples match the "RGB" mode passed to PIL below.
        pix = page.get_pixmap(matrix=fitz.Matrix(3, 3), colorspace=fitz.csRGB, alpha=False)
        img_pil = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        # Preprocessing for Tesseract: grayscale, then boost contrast and sharpness.
        img_pil = img_pil.convert('L')
        img_pil = ImageEnhance.Contrast(img_pil).enhance(2.0)
        img_pil = ImageEnhance.Sharpness(img_pil).enhance(2.0)
        ocr_data = pytesseract.image_to_data(img_pil, output_type=pytesseract.Output.DICT)

        # Factors mapping rendered pixels back to page coordinates, one per axis.
        scale_x = page_width / pix.width
        scale_y = page_height / pix.height

        ocr_tokens = []
        columns = zip(ocr_data['text'], ocr_data['conf'], ocr_data['left'],
                      ocr_data['top'], ocr_data['width'], ocr_data['height'])
        for word, conf, left, top, width, height in columns:
            word = str(word)
            # conf may arrive as an int, a float or a numeric string ("-1", "96.5")
            # depending on pytesseract/Tesseract versions; int("96.5") would raise.
            if not word.strip() or float(conf) <= 50:
                continue
            raw_bbox = [
                left * scale_x, top * scale_y, (left + width) * scale_x, (top + height) * scale_y
            ]
            normalized_bbox = [
                (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
                (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,
                (raw_bbox[2] / page_width) * BBOX_NORM_CONSTANT,
                (raw_bbox[3] / page_height) * BBOX_NORM_CONSTANT
            ]
            ocr_tokens.append({
                "word": word,
                "raw_bbox": [int(b) for b in raw_bbox],
                "normalized_bbox": [int(b) for b in normalized_bbox]
            })
        return ocr_tokens
    except Exception as e:
        print(f"OCR fallback failed: {e}")
        return []
def extract_tokens_from_pdf_fitz_with_ocr(pdf_path: str) -> List[Dict[str, Any]]:
    """Extract words and bounding boxes from every page of a PDF.

    Each page uses PyMuPDF's text layer (``page.get_text("words", sort=True)``).
    Pages that yield no words fall back to ``ocr_fallback_page``. Every token is
    a dict with ``word``, ``raw_bbox`` (page coordinates, truncated to int) and
    ``normalized_bbox`` (0..BBOX_NORM_CONSTANT grid, truncated to int).
    Tokens of all pages are returned in page order.

    Raises:
        RuntimeError: if the PDF cannot be opened or a page fails to process.
            The original error is chained as ``__cause__``.
    """
    all_tokens = []
    try:
        # The context manager closes the document even when a page raises.
        with fitz.open(pdf_path) as doc:
            for page_num in tqdm(range(len(doc)), desc="PDF Page Processing"):
                page = doc.load_page(page_num)
                page_width, page_height = page.rect.width, page.rect.height
                page_tokens = []
                # 1. Primary extraction: tuples of (x0, y0, x1, y1, word, ...).
                for word_data in page.get_text("words", sort=True):
                    x0, y0, x1, y1, word = word_data[:5]
                    page_tokens.append({
                        "word": word,
                        "raw_bbox": [int(x0), int(y0), int(x1), int(y1)],
                        "normalized_bbox": [
                            int((x0 / page_width) * BBOX_NORM_CONSTANT),
                            int((y0 / page_height) * BBOX_NORM_CONSTANT),
                            int((x1 / page_width) * BBOX_NORM_CONSTANT),
                            int((y1 / page_height) * BBOX_NORM_CONSTANT),
                        ],
                    })
                # 2. OCR fallback for pages without an extractable text layer.
                if not page_tokens:
                    print(f" (Page {page_num + 1}) No text layer found. Running OCR...")
                    page_tokens = ocr_fallback_page(page, page_width, page_height)
                all_tokens.extend(page_tokens)
    except Exception as e:
        raise RuntimeError(f"Error opening or processing PDF with fitz/OCR: {e}") from e
    return all_tokens


# Public name used by the inference wrapper.
extract_tokens_from_pdf = extract_tokens_from_pdf_fitz_with_ocr
def preprocess_and_collate_tokens(all_tokens: List[Dict[str, Any]], word_vocab: Vocab, char_vocab: Vocab,
                                  chunk_size: int) -> List[Dict[str, Any]]:
    """Split tokens into chunks and tensorize each chunk as a batch of one.

    Each returned dict holds, on ``DEVICE``:

    * ``words`` - (1, n) word ids
    * ``chars`` - (1, n, MAX_CHAR_LEN) char ids, with long words truncated and
      short ones padded with the char vocab's pad id
    * ``bboxes`` - (1, n, 4) float boxes scaled to [0, 1]
    * ``mask`` - (1, n) all-True bool mask

    The source ``original_tokens`` are kept alongside. Here n <= chunk_size.

    Raises:
        ValueError: if ``chunk_size`` is not a positive integer.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    # Loop-invariant: id used to pad character slots of short words.
    char_pad_index = char_vocab.stoi.get(char_vocab.pad_token, 0)
    all_batches = []
    for start in range(0, len(all_tokens), chunk_size):
        chunk = all_tokens[start:start + chunk_size]
        words = [t["word"] for t in chunk]
        bboxes_norm = [t["normalized_bbox"] for t in chunk]
        # Convert to IDs
        word_ids = [word_vocab[w] for w in words]
        char_ids = []
        for w in words:
            chs = [char_vocab[ch] for ch in w[:MAX_CHAR_LEN]]
            chs.extend([char_pad_index] * (MAX_CHAR_LEN - len(chs)))
            char_ids.append(chs)
        # Build tensors directly on the target device (single-sample batches).
        word_pad = torch.tensor([word_ids], dtype=torch.long, device=DEVICE)
        char_pad = torch.tensor([char_ids], dtype=torch.long, device=DEVICE)
        # normalized_bbox is on a 0..BBOX_NORM_CONSTANT grid; the model expects [0, 1].
        bbox_pad = torch.tensor([bboxes_norm], dtype=torch.float32, device=DEVICE) / BBOX_NORM_CONSTANT
        # Every position in a chunk is a real token, so nothing is masked out.
        mask = torch.ones_like(word_pad, dtype=torch.bool)
        all_batches.append({
            "words": word_pad,
            "chars": char_pad,
            "bboxes": bbox_pad,
            "mask": mask,
            "original_tokens": chunk
        })
    return all_batches
# =========================================================
# 3. Model Loading and Caching (Global Variables Defined Here!)
# =========================================================
# Global variables (MODEL, VOCABS) are defined here for use in the wrapper function.
# They stay None unless EVERYTHING loads, so the wrapper's `MODEL is None`
# check reliably detects a failed startup.
WORD_VOCAB = None
CHAR_VOCAB = None
MODEL = None


def _load_model_and_vocabs():
    """Load both vocabularies and the trained tagger; raise on any failure."""
    word_vocab, char_vocab = load_vocabs(VOCAB_FILE)
    model = MCQTagger(len(word_vocab), len(char_vocab), len(LABELS)).to(DEVICE)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    model.load_state_dict(torch.load(MODEL_FILE, map_location=DEVICE))
    model.eval()
    return word_vocab, char_vocab, model


try:
    # The globals are assigned only after the helper returns. A failed weight
    # load can no longer leave a randomly initialised MODEL behind.
    WORD_VOCAB, CHAR_VOCAB, MODEL = _load_model_and_vocabs()
    print("✅ Model and Vocabs loaded successfully (Cached).")
except Exception as e:
    # This prevents the app from crashing if the model files are missing on startup
    print(f"❌ Initial Model/Vocab Load Failure: {e}")
    print("The Gradio demo will not function until model_CAT.pt and vocabs_CAT.pkl are found.")
# =========================================================
# 4. Structuring Logic (Converts BIO to clean JSON)
# =========================================================
def finalize_passage_to_item(item, passage_buffer):
    """Flush buffered passage words into ``item['passage']``.

    The words in *passage_buffer* are space-joined, whitespace runs are
    collapsed, and the result is stripped. It is then appended to a non-empty
    ``item['passage']`` (space-separated) or stored as the new value. The
    buffer is emptied in place. An empty buffer leaves *item* untouched.
    *item* is returned for convenience.
    """
    if not passage_buffer:
        return item
    joined = ' '.join(passage_buffer)
    passage_text = re.sub(r'\s{2,}', ' ', joined).strip()
    existing = item.get('passage')
    item['passage'] = existing + ' ' + passage_text if existing else passage_text
    passage_buffer.clear()
    return item
def convert_bio_to_structured_json_strict(predictions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Converts a list of {word, predicted_label} tokens into structured MCQ JSON format.

    Each prediction needs ``word`` and ``predicted_label`` (a BIO tag such as
    ``"B-QUESTION"``). The output list contains:

    * at most one ``{'type': 'METADATA', ...}`` item. It holds the words seen
      before the first ``B-QUESTION`` as ``text``, plus any buffered
      ``passage``, and is emitted only if non-empty.
    * one ``{'type': 'MCQ', 'question', 'options', 'answer', 'text'}`` item per
      ``B-QUESTION``, optionally with a ``passage``. ``text`` holds every raw
      word of the block, including ``O``-tagged ones. ``options`` is keyed by
      the first word of each ``B-OPTION`` span.

    An ``I-`` tag is attached only when it continues the entity type of the most
    recent ``B-`` tag; otherwise the word appears only in ``text``. All string
    fields have whitespace runs collapsed and are stripped before returning.

    NOTE: ``LABELS`` in this file has no PASSAGE tags, so the passage branches
    only take effect with a PASSAGE-aware label set.
    """
    structured_data = []
    current_item = None
    current_option_key = None
    current_passage_buffer = []
    # Every word of the current block (metadata header or question), whatever its tag.
    current_text_buffer = []
    first_question_started = False
    # Entity type of the most recent B- tag; gates which I- tags get attached.
    last_entity_type = None
    for item in predictions:
        word = item['word']
        label = item['predicted_label']
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
        current_text_buffer.append(word)
        is_passage_label = (label == 'B-PASSAGE' or label == 'I-PASSAGE')
        # --- BEFORE FIRST QUESTION/METADATA HANDLING ---
        # Non-question, non-passage words before the first question only feed the header text.
        if not first_question_started and label != 'B-QUESTION' and not is_passage_label:
            continue
        # --- PASSAGE HANDLING (Before question start) ---
        if not first_question_started and is_passage_label:
            if label == 'B-PASSAGE' or (label == 'I-PASSAGE' and last_entity_type == 'PASSAGE'):
                current_passage_buffer.append(word)
                last_entity_type = 'PASSAGE'
            continue
        # --- NEW QUESTION START (B-QUESTION) ---
        if label == 'B-QUESTION':
            # 1. Capture leading text/passage as METADATA
            if not first_question_started:
                # [:-1] excludes the B-QUESTION word just appended above.
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    metadata_item = {'type': 'METADATA'}
                    metadata_item = finalize_passage_to_item(metadata_item, current_passage_buffer)
                    if header_text:
                        metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]
            # 2. Save previous question block
            elif current_item is not None:
                current_item = finalize_passage_to_item(current_item, current_passage_buffer)
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                current_text_buffer = [word]
            # 3. Initialize new question
            current_item = {
                'type': 'MCQ',
                'question': word,
                'options': {},
                'answer': '',
                'text': ''
            }
            current_option_key = None
            last_entity_type = 'QUESTION'
            continue
        # --- IF INSIDE A QUESTION BLOCK ---
        if current_item is not None:
            if label.startswith('B-'):
                last_entity_type = entity_type
                if entity_type == 'PASSAGE':
                    # Flush any earlier passage span before starting a new one.
                    finalize_passage_to_item(current_item, current_passage_buffer)
                    current_passage_buffer.append(word)
                elif entity_type == 'OPTION':
                    # The option's first word doubles as its key (e.g. "(A)").
                    current_option_key = word
                    current_item['options'][current_option_key] = word
                    # NOTE(review): rebinding discards buffered passage words without
                    # finalizing them into the item. Confirm this is intended.
                    current_passage_buffer = []
                elif entity_type == 'ANSWER':
                    current_item['answer'] = word
                    current_option_key = None
                    current_passage_buffer = []
                elif entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                    current_passage_buffer = []
            elif label.startswith('I-'):
                # Continuations are attached only to an entity of the same type.
                if entity_type == 'QUESTION' and last_entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                elif entity_type == 'OPTION' and last_entity_type == 'OPTION' and current_option_key is not None:
                    current_item['options'][current_option_key] += f' {word}'
                elif entity_type == 'ANSWER' and last_entity_type == 'ANSWER':
                    current_item['answer'] += f' {word}'
                elif entity_type == 'PASSAGE' and last_entity_type == 'PASSAGE':
                    current_passage_buffer.append(word)
            elif label == 'O':
                # 'O' words are kept only in the block's raw text.
                pass
    # --- Finalize last item ---
    if current_item is not None:
        current_item = finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = re.sub(r'\s{2,}', ' ', ' '.join(current_text_buffer)).strip()
        structured_data.append(current_item)
    elif not structured_data and current_passage_buffer:
        # Case: Only passage/metadata was present in the whole document
        metadata_item = {'type': 'METADATA'}
        metadata_item = finalize_passage_to_item(metadata_item, current_passage_buffer)
        metadata_item['text'] = re.sub(r'\s{2,}', ' ', ' '.join(current_text_buffer)).strip()
        structured_data.append(metadata_item)
    # --- FINAL CLEANUP ---
    for item in structured_data:
        # Clean up all text fields for excessive whitespace
        # NOTE(review): assumes every item has 'text'. A METADATA item built
        # from a passage with a blank header would lack it and raise KeyError.
        item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
        if 'passage' in item:
            item['passage'] = re.sub(r'\s{2,}', ' ', item['passage']).strip()
            if not item['passage']:
                del item['passage']
        for field in ['question', 'answer']:
            if field in item:
                item[field] = re.sub(r'\s{2,}', ' ', item[field]).strip()
        if 'options' in item:
            for k, v in item['options'].items():
                item['options'][k] = re.sub(r'\s{2,}', ' ', v).strip()
    return structured_data
# =========================================================
# 5. The Gradio Inference Wrapper Function (Main Entry Point)
# =========================================================
def gradio_inference_wrapper(pdf_file: str) -> Tuple[str, List[Dict[str, Any]]]:
    """
    Wraps the entire two-stage pipeline: (1) Tagging -> (2) Structuring.

    Args:
        pdf_file: The upload delivered by ``gr.File``. This is a filesystem path,
            or a tempfile-like object exposing the path as ``.name``; it is
            None when nothing was uploaded.

    Returns:
        ``(status_message, structured_items)``. On any failure the list is empty
        and the message starts with "❌".
    """
    # Uses global variables defined in Section 3
    if MODEL is None:
        return "❌ ERROR: Model failed to load on startup. Check 'model_CAT.pt' and 'vocabs_CAT.pkl'.", []
    if pdf_file is None:
        # fitz.open(None) would silently create an empty document instead of failing.
        return "❌ ERROR: No file uploaded. Please upload a PDF document.", []
    # Accept both a plain path and a file wrapper that carries the path in `.name`.
    pdf_path = getattr(pdf_file, "name", pdf_file)
    raw_predictions = []
    try:
        # 1. Stage 1: PDF Processing and BIO Tagging
        all_tokens = extract_tokens_from_pdf(pdf_path)
        if not all_tokens:
            return "❌ ERROR: No tokens were extracted from the PDF, even after OCR fallback.", []
        batches = preprocess_and_collate_tokens(all_tokens, WORD_VOCAB, CHAR_VOCAB, chunk_size=INFERENCE_CHUNK_SIZE)
        with torch.no_grad():
            for batch in batches:
                preds_batch = MODEL(batch["words"], batch["chars"], batch["bboxes"], batch["mask"])
                # Each batch holds exactly one chunk, so only the first decoded path exists.
                for token_data, pred_idx in zip(batch["original_tokens"], preds_batch[0]):
                    raw_predictions.append({
                        "word": token_data["word"],
                        "bbox": token_data["raw_bbox"],
                        # int() normalises numpy/tensor scalars so the dict lookup succeeds.
                        "predicted_label": IDX2LABEL[int(pred_idx)]
                    })
        # 2. Stage 2: Structured JSON Conversion
        structured_output = convert_bio_to_structured_json_strict(raw_predictions)
        mcq_count = sum(1 for i in structured_output if i.get('type') == 'MCQ')
        status_message = f"✅ Conversion complete. Found {mcq_count} MCQ items and {len(structured_output) - mcq_count} Metadata blocks."
        return status_message, structured_output
    except RuntimeError as e:
        return f"❌ PDF Processing Error: {e}", []
    except Exception as e:
        return f"❌ An unexpected processing error occurred: {e}", []
# =========================================================
# 6. Define and Launch the Gradio Interface
# =========================================================
if __name__ == "__main__":
    title = "MCQ Document Structure Tagger (Bi-LSTM-CRF) - Structured Output"
    description = "Upload a PDF document. The system processes it in two stages: 1) BIO-Tagging for structural elements (Question, Option, Answer, Passage) and 2) Converting those tags into a clean, structured JSON list of MCQ items."
    demo = gr.Interface(
        fn=gradio_inference_wrapper,
        # Restrict the upload picker to PDF files (the previous comment promised
        # this, but no file_types filter was set).
        inputs=gr.File(label="Upload PDF Document", file_types=[".pdf"]),
        outputs=[
            gr.Textbox(label="Status Message", interactive=False),
            gr.JSON(label="Structured MCQ JSON Output", show_label=True)
        ],
        title=title,
        description=description,
        allow_flagging="never",
        concurrency_limit=2
    )
    demo.launch(show_error=True)