aagamjtdev committed on
Commit 0dc2968 · 1 Parent(s): 9c882b5

Initial deployment with LFS-tracked model

Files changed (4)
  1. app.py +435 -0
  2. model_CAT.pt +3 -0
  3. requirements.txt +6 -0
  4. vocabs_CAT.pkl +3 -0
app.py ADDED
@@ -0,0 +1,435 @@
+ import os
+ import json
+ import pickle
+ from typing import List, Dict, Any, Tuple
+ from collections import Counter
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from tqdm import tqdm
+
+ # === GRADIO AND DEPENDENCIES ===
+ import gradio as gr
+ import fitz  # PyMuPDF
+ import re
+ from PIL import Image, ImageEnhance
+ import pytesseract
+
+ try:
+     from TorchCRF import CRF
+ except ImportError:
+     # requirements.txt installs TorchCRF on the Space; this local stub only keeps
+     # the module importable and falls back to greedy argmax decoding.
+     print("CRF module not found. Assuming deployment environment will install it.")
+
+     class CRF:
+         def __init__(self, *args, **kwargs): pass
+
+         def viterbi_decode(self, emissions, mask):
+             return [list(torch.argmax(emissions[0], dim=-1).cpu().numpy())]
+
+ # ========== CONFIG (Must match Training Script) ==========
+ # NOTE: In a Space, we typically don't use DATA_DIR paths if the files are alongside app.py
+ MODEL_FILE = "model_CAT.pt"
+ VOCAB_FILE = "vocabs_CAT.pkl"
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MAX_CHAR_LEN = 16
+ EMBED_DIM = 100
+ CHAR_EMBED_DIM = 30
+ CHAR_CNN_OUT = 30
+ BBOX_DIM = 100
+ HIDDEN_SIZE = 512
+ BBOX_NORM_CONSTANT = 1000.0
+ INFERENCE_CHUNK_SIZE = 256
+
+ # ========== LABELS (Must match Training Script) ==========
+ LABELS = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION", "B-ANSWER", "I-ANSWER", "B-IMAGE", "I-IMAGE"]
+ LABEL2IDX = {l: i for i, l in enumerate(LABELS)}
+ IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
+
+
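+ # Illustration of the BIO scheme over a token stream (hypothetical example):
+ #   "Q1."  "What" "is" "2+2?"  ->  B-QUESTION I-QUESTION I-QUESTION I-QUESTION
+ #   "(A)"  "3"                 ->  B-OPTION   I-OPTION
+ #   "Ans:" "4"                 ->  B-ANSWER   I-ANSWER
+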
+ # =========================================================
+ # 1. Vocab, CharCNNEncoder, and MCQTagger Classes (must match the training script)
+ # =========================================================
+
+ class Vocab:
+     def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
+         self.min_freq = min_freq
+         self.unk_token = unk_token
+         self.pad_token = pad_token
+         self.freq = Counter()
+         self.itos = []  # Index to String
+         self.stoi = {}  # String to Index
+
+     def add_sentence(self, toks):
+         self.freq.update(toks)
+
+     def build(self):
+         items = [tok for tok, c in self.freq.items() if c >= self.min_freq]
+         items = [self.pad_token, self.unk_token] + sorted(items)
+         self.itos = items
+         self.stoi = {s: i for i, s in enumerate(self.itos)}
+
+     def __len__(self):
+         return len(self.itos)
+
+     def __getitem__(self, token: str) -> int:
+         """Allows lookup via word_vocab[token]. Returns the UNK index if the token is not found."""
+         return self.stoi.get(token, self.stoi[self.unk_token])
+
+     def __getstate__(self):
+         return {
+             'min_freq': self.min_freq,
+             'unk_token': self.unk_token,
+             'pad_token': self.pad_token,
+             'itos': self.itos,
+             'stoi': self.stoi,
+         }
+
+     def __setstate__(self, state):
+         self.min_freq = state['min_freq']
+         self.unk_token = state['unk_token']
+         self.pad_token = state['pad_token']
+         self.itos = state['itos']
+         self.stoi = state['stoi']
+         self.freq = Counter()
+
+
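+ # Usage sketch (illustrative only):
+ #   v = Vocab(); v.add_sentence(["What", "is", "What"]); v.build()
+ #   v["What"]   -> an index >= 2 (indices 0 and 1 are <PAD> and <UNK>)
+ #   v["unseen"] -> the <UNK> index, via __getitem__'s fallback
+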
+ def load_vocabs(path: str) -> Tuple[Vocab, Vocab]:
+     """Loads word and character vocabularies from a pickle file and verifies size."""
+     try:
+         absolute_path = os.path.abspath(path)
+         if not os.path.exists(absolute_path):
+             raise FileNotFoundError(f"Vocab file NOT FOUND at: {absolute_path}")
+         with open(absolute_path, "rb") as f:
+             word_vocab, char_vocab = pickle.load(f)
+         if len(word_vocab) <= 2:
+             raise IndexError("CRITICAL: Word vocabulary size is too small. Vocab file is invalid.")
+         return word_vocab, char_vocab
+     except FileNotFoundError:
+         raise FileNotFoundError(f"Vocab file not found at {path}. Please run the training script first.")
+     except Exception as e:
+         raise RuntimeError(f"Error loading vocabs from {path}: {e}")
+
+
+ class CharCNNEncoder(nn.Module):
+     def __init__(self, char_vocab_size, char_emb_dim, out_dim, kernel_sizes=(3, 4, 5)):
+         super().__init__()
+         self.char_emb = nn.Embedding(char_vocab_size, char_emb_dim, padding_idx=0)
+         convs = [nn.Conv1d(char_emb_dim, out_dim, kernel_size=k) for k in kernel_sizes]
+         self.convs = nn.ModuleList(convs)
+         self.out_dim = out_dim * len(convs)
+
+     def forward(self, char_ids):
+         B, L, C = char_ids.size()
+         emb = self.char_emb(char_ids.view(B * L, C)).transpose(1, 2)
+         outs = [torch.max(torch.relu(conv(emb)), dim=2)[0] for conv in self.convs]
+         res = torch.cat(outs, dim=1)
+         return res.view(B, L, -1)
+
+
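+ # Shape walk-through for forward(): char_ids (B, L, 16) is flattened to (B*L, 16),
+ # embedded to (B*L, 16, 30) and transposed to (B*L, 30, 16); each Conv1d + max-pool
+ # yields (B*L, 30), and concatenating the three kernel sizes gives (B, L, 90).
+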
+ class MCQTagger(nn.Module):
+     def __init__(self, vocab_size, char_vocab_size, n_labels, bbox_dim=BBOX_DIM):
+         super().__init__()
+         self.word_emb = nn.Embedding(vocab_size, EMBED_DIM, padding_idx=0)
+         self.char_enc = CharCNNEncoder(char_vocab_size, CHAR_EMBED_DIM, CHAR_CNN_OUT)
+         self.bbox_proj = nn.Linear(4, bbox_dim)
+         in_dim = EMBED_DIM + self.char_enc.out_dim + bbox_dim
+
+         self.bilstm = nn.LSTM(in_dim, HIDDEN_SIZE // 2, num_layers=2, batch_first=True, bidirectional=True, dropout=0.3)
+         self.ff = nn.Linear(HIDDEN_SIZE, n_labels)
+         self.crf = CRF(n_labels)
+         self.dropout = nn.Dropout(p=0.5)
+
+     def forward_emissions(self, words, chars, bboxes, mask):
+         wemb = self.word_emb(words)
+         cenc = self.char_enc(chars)
+         benc = self.bbox_proj(bboxes)
+         enc_in = torch.cat([wemb, cenc, benc], dim=-1)
+         enc_in = self.dropout(enc_in)
+         lengths = mask.sum(dim=1).cpu()
+
+         if lengths.max().item() == 0:
+             B, L = enc_in.size(0), enc_in.size(1)
+             return torch.zeros((B, L, len(LABELS)), device=enc_in.device)
+
+         packed_in = nn.utils.rnn.pack_padded_sequence(enc_in, lengths, batch_first=True, enforce_sorted=False)
+         packed_out, _ = self.bilstm(packed_in)
+         padded_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
+
+         return self.ff(padded_out)
+
+     def forward(self, words, chars, bboxes, mask, labels=None, class_weights=None, alpha=0.7):
+         emissions = self.forward_emissions(words, chars, bboxes, mask)
+         # Inference only: decode the best tag sequence with the CRF; no loss is computed here.
+         return self.crf.viterbi_decode(emissions, mask=mask)
+
+
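+ # Feature dimensions at a glance: word embedding (100) + char-CNN (90) + bbox
+ # projection (100) give a 290-dim input; the 2-layer BiLSTM emits 512 features
+ # (256 per direction), and the linear head scores the 9 BIO labels before CRF decoding.
+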
+ # =========================================================
+ # 2. PDF Processing Functions
+ # =========================================================
+
+ def ocr_fallback_page(page: fitz.Page, page_width: float, page_height: float) -> List[Dict[str, Any]]:
+     """
+     Renders a PyMuPDF page, runs Tesseract OCR, and tokenizes the result.
+     """
+     try:
+         # Render the page at high resolution (300 DPI equivalent)
+         pix = page.get_pixmap(matrix=fitz.Matrix(3, 3))
+         if pix.n - pix.alpha > 3:  # Handle CMYK
+             pix = fitz.Pixmap(fitz.csRGB, pix)
+
+         img_pil = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+
+         # Preprocessing for Tesseract: grayscale, then boost contrast and sharpness
+         img_pil = img_pil.convert('L')
+         img_pil = ImageEnhance.Contrast(img_pil).enhance(2.0)
+         img_pil = ImageEnhance.Sharpness(img_pil).enhance(2.0)
+
+         # Run Tesseract
+         ocr_data = pytesseract.image_to_data(img_pil, output_type=pytesseract.Output.DICT)
+
+         ocr_tokens = []
+         for i in range(len(ocr_data['text'])):
+             word = ocr_data['text'][i]
+             conf = ocr_data['conf'][i]
+
+             # Use only words with reasonable confidence
+             if word.strip() and int(conf) > 50:
+                 # Get Tesseract's raw pixel bounding box
+                 left = ocr_data['left'][i]
+                 top = ocr_data['top'][i]
+                 width = ocr_data['width'][i]
+                 height = ocr_data['height'][i]
+
+                 # Convert the pixel bbox back to the original PDF coordinate system
+                 scale = page_width / pix.width
+
+                 raw_bbox = [
+                     left * scale,
+                     top * scale,
+                     (left + width) * scale,
+                     (top + height) * scale
+                 ]
+
+                 # Normalize the bbox to the model's 0-1000 coordinate space
+                 normalized_bbox = [
+                     (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
+                     (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,
+                     (raw_bbox[2] / page_width) * BBOX_NORM_CONSTANT,
+                     (raw_bbox[3] / page_height) * BBOX_NORM_CONSTANT
+                 ]
+
+                 ocr_tokens.append({
+                     "word": word,
+                     "raw_bbox": [int(b) for b in raw_bbox],
+                     "normalized_bbox": [int(b) for b in normalized_bbox]
+                 })
+
+         return ocr_tokens
+
+     except Exception as e:
+         print(f"OCR fallback failed: {e}")
+         return []
+
+
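+ # Coordinate example (hypothetical page size): at the 3x zoom above, a US-Letter
+ # page (612 pt wide) renders to pix.width = 1836, so a Tesseract box at pixel
+ # x = 300 maps back to 300 * (612 / 1836) = 100 pt, and then normalizes to
+ # (100 / 612) * 1000 ≈ 163 in the model's 0-1000 bbox space.
+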
+ def extract_tokens_from_pdf_fitz_with_ocr(pdf_path: str) -> List[Dict[str, Any]]:
+     """
+     Extracts words and their raw bounding boxes from the PyMuPDF (fitz) text layer,
+     falling back to OCR on pages where no text is found.
+     """
+     all_tokens = []
+     try:
+         doc = fitz.open(pdf_path)
+         for page_num in tqdm(range(len(doc)), desc="PDF Page Processing"):
+             page = doc.load_page(page_num)
+             page_width, page_height = page.rect.width, page.rect.height
+             page_tokens = []
+
+             # 1. Primary Extraction: PyMuPDF's word structure (fitz.Page.get_text("words"))
+             # word_list format: (x0, y0, x1, y1, word, ...)
+             word_list = page.get_text("words", sort=True)
+
+             if word_list:
+                 for word_data in word_list:
+                     word = word_data[4]
+                     raw_bbox = word_data[:4]
+
+                     # Normalize bboxes
+                     normalized_bbox = [
+                         (raw_bbox[0] / page_width) * BBOX_NORM_CONSTANT,
+                         (raw_bbox[1] / page_height) * BBOX_NORM_CONSTANT,
+                         (raw_bbox[2] / page_width) * BBOX_NORM_CONSTANT,
+                         (raw_bbox[3] / page_height) * BBOX_NORM_CONSTANT
+                     ]
+
+                     page_tokens.append({
+                         "word": word,
+                         "raw_bbox": [int(b) for b in raw_bbox],
+                         "normalized_bbox": [int(b) for b in normalized_bbox]
+                     })
+
+             # 2. OCR Fallback
+             if not page_tokens:
+                 print(f" (Page {page_num + 1}) No text layer found. Running OCR...")
+                 page_tokens = ocr_fallback_page(page, page_width, page_height)
+
+             all_tokens.extend(page_tokens)
+
+         doc.close()
+     except Exception as e:
+         raise RuntimeError(f"Error opening or processing PDF with fitz/OCR: {e}")
+
+     return all_tokens
+
+
+ extract_tokens_from_pdf = extract_tokens_from_pdf_fitz_with_ocr
+
+
+ def preprocess_and_collate_tokens(all_tokens: List[Dict[str, Any]], word_vocab: Vocab, char_vocab: Vocab,
+                                   chunk_size: int) -> List[Dict[str, Any]]:
+     """
+     Chunks the token list, converts tokens to IDs, and prepares batches for inference.
+     """
+     all_batches = []
+
+     for i in range(0, len(all_tokens), chunk_size):
+         chunk = all_tokens[i:i + chunk_size]
+         if not chunk:
+             continue
+
+         words = [t["word"] for t in chunk]
+         bboxes_norm = [t["normalized_bbox"] for t in chunk]
+
+         # Convert to IDs
+         word_ids = [word_vocab[w] for w in words]
+
+         char_ids = []
+         for w in words:
+             chs = [char_vocab[ch] for ch in w[:MAX_CHAR_LEN]]
+             if len(chs) < MAX_CHAR_LEN:
+                 pad_index = char_vocab.stoi.get(char_vocab.pad_token, 0)
+                 chs += [pad_index] * (MAX_CHAR_LEN - len(chs))
+             char_ids.append(chs)
+
+         # Create padded tensors (using single-sample batches)
+         word_pad = torch.LongTensor([word_ids]).to(DEVICE)
+         char_pad = torch.LongTensor([char_ids]).to(DEVICE)
+
+         # Final normalization to the [0, 1] range before feeding the model
+         bbox_pad = torch.FloatTensor([bboxes_norm]).to(DEVICE) / BBOX_NORM_CONSTANT
+         mask = torch.ones(word_pad.size(), dtype=torch.bool).to(DEVICE)
+
+         all_batches.append({
+             "words": word_pad,
+             "chars": char_pad,
+             "bboxes": bbox_pad,
+             "mask": mask,
+             "original_tokens": chunk  # Keep the original data for output formatting
+         })
+
+     return all_batches
+
+
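+ # Chunking example: 600 extracted tokens with INFERENCE_CHUNK_SIZE = 256 yield
+ # three single-sample batches of lengths 256, 256, and 88; there is no
+ # cross-sample padding, so the all-ones mask is always valid.
+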
+ # =========================================================
+ # 3. Model Loading and Caching (Crucial for Gradio performance)
+ # =========================================================
+
+ # Cache the model and vocabs globally so they are loaded only ONCE when the app starts.
+ # This avoids reloading the model on every user request, which is vital for speed.
+ try:
+     WORD_VOCAB, CHAR_VOCAB = load_vocabs(VOCAB_FILE)
+     MODEL = MCQTagger(len(WORD_VOCAB), len(CHAR_VOCAB), len(LABELS)).to(DEVICE)
+     MODEL.load_state_dict(torch.load(MODEL_FILE, map_location=DEVICE))
+     MODEL.eval()
+     print("✅ Model and Vocabs loaded successfully (Cached).")
+ except Exception as e:
+     MODEL = None
+     print(f"❌ Initial Model/Vocab Load Failure: {e}")
+     print("The Gradio demo will not function until model_CAT.pt and vocabs_CAT.pkl are in the root directory.")
+
+
+ # =========================================================
+ # 4. The Gradio Inference Wrapper Function
+ # =========================================================
+
+ def gradio_inference_wrapper(pdf_file: str) -> Tuple[str, List[Dict[str, Any]]]:
+     """
+     Wraps the entire inference pipeline for the Gradio Interface.
+
+     Args:
+         pdf_file: The path to the temporary PDF file uploaded by the user (a string).
+
+     Returns:
+         A tuple of (str, List[Dict[str, Any]]): a status message and the raw predictions.
+     """
+     if MODEL is None:
+         return "❌ ERROR: Model failed to load on startup. Check 'model_CAT.pt' and 'vocabs_CAT.pkl'.", []
+
+     pdf_path = pdf_file
+
+     try:
+         # 1. Extract Tokens
+         all_tokens = extract_tokens_from_pdf(pdf_path)
+     except RuntimeError as e:
+         return f"❌ PDF Processing Error: {e}", []
+
+     if not all_tokens:
+         return "❌ ERROR: No tokens were extracted from the PDF, even after OCR fallback.", []
+
+     # 2. Preprocess and Batch
+     batches = preprocess_and_collate_tokens(all_tokens, WORD_VOCAB, CHAR_VOCAB, chunk_size=INFERENCE_CHUNK_SIZE)
+
+     # 3. Run Inference
+     all_predictions = []
+     with torch.no_grad():
+         for batch in batches:
+             words, chars, bboxes, mask = (batch[k] for k in ["words", "chars", "bboxes", "mask"])
+
+             preds_batch = MODEL(words, chars, bboxes, mask)
+             predictions = preds_batch[0]
+
+             original_tokens = batch["original_tokens"]
+
+             for token_data, pred_idx in zip(original_tokens, predictions):
+                 all_predictions.append({
+                     "word": token_data["word"],
+                     "bbox": token_data["raw_bbox"],
+                     "predicted_label": IDX2LABEL[pred_idx]
+                 })
+
+     status_message = f"✅ Inference complete. Total tokens predicted: {len(all_predictions)}"
+
+     # Gradio will display the JSON output prettified
+     return status_message, all_predictions
+
+
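+ # Local smoke test (filename is illustrative; assumes a sample.pdf next to app.py):
+ #   status, preds = gradio_inference_wrapper("sample.pdf")
+ #   print(status); print(preds[:3])
+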
+ # =========================================================
+ # 5. Define and Launch the Gradio Interface
+ # =========================================================
+
+ if __name__ == "__main__":
+     title = "MCQ Document Structure Tagger (Bi-LSTM-CRF)"
+     description = "Upload a PDF document (e.g., an MCQ paper). The model will tokenize the text, run inference to predict BIO tags (B-QUESTION, I-OPTION, B-ANSWER, etc.) for each word, and return the raw JSON predictions."
+
+     # Define the Gradio Interface
+     demo = gr.Interface(
+         fn=gradio_inference_wrapper,
+         inputs=gr.File(label="Upload PDF Document", file_types=['.pdf']),
+         outputs=[
+             gr.Textbox(label="Status Message", interactive=False),
+             gr.JSON(label="Raw BIO Tagging Predictions (JSON)", show_label=True)
+         ],
+         title=title,
+         description=description,
+         allow_flagging="never",
+         # A modest concurrency limit (simultaneous users) suits a CPU/small-GPU Space
+         concurrency_limit=2
+     )
+
+     # Launch the demo
+     demo.launch()
model_CAT.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e7e571ec922de9e9d5095e3a2ef6b670895e1947c5be09db7c1112a49528ceda
+ size 15461951
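This is a Git LFS pointer, not the ~15 MB checkpoint itself; if a clone skips LFS, torch.load() will fail on the text stub. A minimal guard sketch (a hypothetical helper, not part of this commit) that app.py could run before loading:

    def looks_like_lfs_pointer(path: str) -> bool:
        # LFS pointers are tiny text files beginning with "version https://git-lfs..."
        try:
            with open(path, "rb") as f:
                return f.read(12) == b"version http"
        except OSError:
            return False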
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio
+ torch
+ PyMuPDF
+ pytesseract
+ TorchCRF
+ Pillow
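A few packaging notes. The CRF requirement is spelled TorchCRF above so that it matches the "from TorchCRF import CRF" import in app.py; the original lowercase torch-crf spelling does not correspond to that import name. pytesseract only wraps the Tesseract binary, which pip does not install; on a Hugging Face Space the binary is conventionally supplied through an apt packages.txt file, e.g.:

    tesseract-ocr

Finally, tqdm (imported by app.py) arrives only as a transitive dependency of gradio; listing it explicitly would be safer.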
vocabs_CAT.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ace7379c6800c1f13f3859c7181b9be2a0d539debe762cf83739a93c20fb7f70
+ size 209360
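A quick way to sanity-check this artifact after the LFS fetch is to unpickle it and inspect the vocabulary sizes; a minimal sketch, assuming the Vocab class is resolvable under the same module path recorded when the file was pickled (app.py run as a script satisfies this):

    import pickle

    with open("vocabs_CAT.pkl", "rb") as f:
        word_vocab, char_vocab = pickle.load(f)
    print(len(word_vocab), len(char_vocab))  # load_vocabs() rejects word vocabs of size <= 2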