Update app.py
Browse files
app.py
CHANGED
|
@@ -554,7 +554,6 @@
|
|
| 554 |
# demo.launch(show_error=True)
|
| 555 |
|
| 556 |
|
| 557 |
-
|
| 558 |
import os
|
| 559 |
import json
|
| 560 |
import pickle
|
|
@@ -571,8 +570,7 @@ import sys
|
|
| 571 |
from types import ModuleType
|
| 572 |
|
| 573 |
# --- 1. CRITICAL: MOCK THE TRAINING MODULE ---
|
| 574 |
-
#
|
| 575 |
-
# We create a fake module and inject our local classes into it so torch.load works.
|
| 576 |
train_mod = ModuleType("train_model")
|
| 577 |
sys.modules["train_model"] = train_mod
|
| 578 |
|
|
@@ -580,9 +578,14 @@ sys.modules["train_model"] = train_mod
|
|
| 580 |
try:
|
| 581 |
from torch_crf import CRF
|
| 582 |
except ImportError:
|
| 583 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
|
| 585 |
-
# --- 3. CONFIG
|
| 586 |
MODEL_FILE = "model_enhanced.pt"
|
| 587 |
VOCAB_FILE = "vocabs_enhanced.pkl"
|
| 588 |
DEVICE = torch.device("cpu")
|
|
@@ -605,7 +608,19 @@ LABELS = [
|
|
| 605 |
]
|
| 606 |
IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
|
| 607 |
|
| 608 |
-
# --- 4.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
|
| 610 |
class CharCNNEncoder(nn.Module):
|
| 611 |
def __init__(self, char_vocab_size, char_emb_dim, out_dim, kernel_sizes=(2, 3, 4, 5)):
|
|
@@ -681,12 +696,14 @@ class MCQTagger(nn.Module):
|
|
| 681 |
emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
|
| 682 |
return self.crf.viterbi_decode(emissions, mask=mask)
|
| 683 |
|
| 684 |
-
#
|
|
|
|
|
|
|
| 685 |
train_mod.MCQTagger = MCQTagger
|
| 686 |
train_mod.CharCNNEncoder = CharCNNEncoder
|
|
|
|
| 687 |
|
| 688 |
-
# ---
|
| 689 |
-
|
| 690 |
def extract_spatial_features(tokens, idx):
|
| 691 |
curr = tokens[idx]
|
| 692 |
f = []
|
|
@@ -732,22 +749,23 @@ def extract_context_features(tokens, idx, window=3):
|
|
| 732 |
f.extend([dq, dopt])
|
| 733 |
return f
|
| 734 |
|
| 735 |
-
# ---
|
| 736 |
-
|
| 737 |
def gradio_inference(pdf_file):
|
| 738 |
if not os.path.exists(MODEL_FILE) or not os.path.exists(VOCAB_FILE):
|
| 739 |
return "❌ Missing model/vocab files.", []
|
| 740 |
|
| 741 |
try:
|
|
|
|
| 742 |
with open(VOCAB_FILE, "rb") as f:
|
| 743 |
word_vocab, char_vocab = pickle.load(f)
|
| 744 |
|
|
|
|
| 745 |
model = MCQTagger(len(word_vocab), len(char_vocab), len(LABELS)).to(DEVICE)
|
| 746 |
state_dict = torch.load(MODEL_FILE, map_location=DEVICE)
|
| 747 |
model.load_state_dict(state_dict if isinstance(state_dict, dict) else state_dict.state_dict())
|
| 748 |
model.eval()
|
| 749 |
|
| 750 |
-
#
|
| 751 |
doc = fitz.open(pdf_file.name)
|
| 752 |
all_tokens = []
|
| 753 |
for page in doc:
|
|
@@ -756,16 +774,28 @@ def gradio_inference(pdf_file):
|
|
| 756 |
all_tokens.append({'text': text, 'x0': x0, 'y0': y0, 'x1': x1, 'y1': y1})
|
| 757 |
doc.close()
|
| 758 |
|
| 759 |
-
|
|
|
|
|
|
|
| 760 |
for i in range(len(all_tokens)):
|
| 761 |
all_tokens[i]['spatial_features'] = extract_spatial_features(all_tokens, i)
|
| 762 |
all_tokens[i]['context_features'] = extract_context_features(all_tokens, i)
|
| 763 |
|
|
|
|
| 764 |
results = []
|
| 765 |
for i in range(0, len(all_tokens), INFERENCE_CHUNK_SIZE):
|
| 766 |
chunk = all_tokens[i : i + INFERENCE_CHUNK_SIZE]
|
|
|
|
|
|
|
| 767 |
w_ids = torch.LongTensor([[word_vocab[t['text']] for t in chunk]]).to(DEVICE)
|
| 768 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 769 |
bboxes = torch.FloatTensor([[[t['x0']/1000.0, t['y0']/1000.0, t['x1']/1000.0, t['y1']/1000.0] for t in chunk]]).to(DEVICE)
|
| 770 |
s_feats = torch.FloatTensor([[t['spatial_features'] for t in chunk]]).to(DEVICE)
|
| 771 |
c_feats = torch.FloatTensor([[t['context_features'] for t in chunk]]).to(DEVICE)
|
|
@@ -776,17 +806,18 @@ def gradio_inference(pdf_file):
|
|
| 776 |
for t, p in zip(chunk, preds):
|
| 777 |
results.append({"word": t['text'], "label": IDX2LABEL[p]})
|
| 778 |
|
| 779 |
-
return "✅
|
|
|
|
| 780 |
except Exception as e:
|
|
|
|
| 781 |
return f"❌ Error: {str(e)}", []
|
| 782 |
|
| 783 |
-
# ---
|
| 784 |
-
|
| 785 |
demo = gr.Interface(
|
| 786 |
fn=gradio_inference,
|
| 787 |
-
inputs=gr.File(label="Upload
|
| 788 |
outputs=[gr.Textbox(label="Status"), gr.JSON(label="Predictions")],
|
| 789 |
-
title="
|
| 790 |
)
|
| 791 |
|
| 792 |
if __name__ == "__main__":
|
|
|
|
| 554 |
# demo.launch(show_error=True)
|
| 555 |
|
| 556 |
|
|
|
|
| 557 |
import os
|
| 558 |
import json
|
| 559 |
import pickle
|
|
|
|
| 570 |
from types import ModuleType
|
| 571 |
|
| 572 |
# --- 1. CRITICAL: MOCK THE TRAINING MODULE ---
|
| 573 |
+
# We create a fake module to satisfy pickle/torch.load
|
|
|
|
| 574 |
train_mod = ModuleType("train_model")
|
| 575 |
sys.modules["train_model"] = train_mod
|
| 576 |
|
|
|
|
| 578 |
try:
|
| 579 |
from torch_crf import CRF
|
| 580 |
except ImportError:
|
| 581 |
+
try:
|
| 582 |
+
from TorchCRF import CRF
|
| 583 |
+
except ImportError:
|
| 584 |
+
# Fallback if libraries are missing (prevents crash, but model won't load)
|
| 585 |
+
class CRF(nn.Module):
|
| 586 |
+
def __init__(self, *args, **kwargs): super().__init__()
|
| 587 |
|
| 588 |
+
# --- 3. CONFIG ---
|
| 589 |
MODEL_FILE = "model_enhanced.pt"
|
| 590 |
VOCAB_FILE = "vocabs_enhanced.pkl"
|
| 591 |
DEVICE = torch.device("cpu")
|
|
|
|
| 608 |
]
|
| 609 |
IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
|
| 610 |
|
| 611 |
+
# --- 4. CLASSES (Re-defined to match training) ---
|
| 612 |
+
|
| 613 |
+
class Vocab:
|
| 614 |
+
def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
|
| 615 |
+
self.min_freq = min_freq
|
| 616 |
+
self.unk_token = unk_token
|
| 617 |
+
self.pad_token = pad_token
|
| 618 |
+
self.freq = Counter()
|
| 619 |
+
self.itos = []
|
| 620 |
+
self.stoi = {}
|
| 621 |
+
|
| 622 |
+
def __len__(self): return len(self.itos)
|
| 623 |
+
def __getitem__(self, token): return self.stoi.get(token, self.stoi.get(self.unk_token, 0))
|
| 624 |
|
| 625 |
class CharCNNEncoder(nn.Module):
|
| 626 |
def __init__(self, char_vocab_size, char_emb_dim, out_dim, kernel_sizes=(2, 3, 4, 5)):
|
|
|
|
| 696 |
emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
|
| 697 |
return self.crf.viterbi_decode(emissions, mask=mask)
|
| 698 |
|
| 699 |
+
# --- 5. CRITICAL FIX: LINK CLASSES TO FAKE MODULE ---
|
| 700 |
+
# This tells pickle: "When you look for 'train_model.Vocab', look here instead."
|
| 701 |
+
train_mod.Vocab = Vocab
|
| 702 |
train_mod.MCQTagger = MCQTagger
|
| 703 |
train_mod.CharCNNEncoder = CharCNNEncoder
|
| 704 |
+
train_mod.SpatialAttention = SpatialAttention
|
| 705 |
|
| 706 |
+
# --- 6. FEATURE EXTRACTORS ---
|
|
|
|
| 707 |
def extract_spatial_features(tokens, idx):
|
| 708 |
curr = tokens[idx]
|
| 709 |
f = []
|
|
|
|
| 749 |
f.extend([dq, dopt])
|
| 750 |
return f
|
| 751 |
|
| 752 |
+
# --- 7. INFERENCE WRAPPER ---
|
|
|
|
| 753 |
def gradio_inference(pdf_file):
|
| 754 |
if not os.path.exists(MODEL_FILE) or not os.path.exists(VOCAB_FILE):
|
| 755 |
return "❌ Missing model/vocab files.", []
|
| 756 |
|
| 757 |
try:
|
| 758 |
+
# Load Vocab
|
| 759 |
with open(VOCAB_FILE, "rb") as f:
|
| 760 |
word_vocab, char_vocab = pickle.load(f)
|
| 761 |
|
| 762 |
+
# Load Model
|
| 763 |
model = MCQTagger(len(word_vocab), len(char_vocab), len(LABELS)).to(DEVICE)
|
| 764 |
state_dict = torch.load(MODEL_FILE, map_location=DEVICE)
|
| 765 |
model.load_state_dict(state_dict if isinstance(state_dict, dict) else state_dict.state_dict())
|
| 766 |
model.eval()
|
| 767 |
|
| 768 |
+
# Parse PDF
|
| 769 |
doc = fitz.open(pdf_file.name)
|
| 770 |
all_tokens = []
|
| 771 |
for page in doc:
|
|
|
|
| 774 |
all_tokens.append({'text': text, 'x0': x0, 'y0': y0, 'x1': x1, 'y1': y1})
|
| 775 |
doc.close()
|
| 776 |
|
| 777 |
+
if not all_tokens: return "❌ No text found.", []
|
| 778 |
+
|
| 779 |
+
# Feature Extraction
|
| 780 |
for i in range(len(all_tokens)):
|
| 781 |
all_tokens[i]['spatial_features'] = extract_spatial_features(all_tokens, i)
|
| 782 |
all_tokens[i]['context_features'] = extract_context_features(all_tokens, i)
|
| 783 |
|
| 784 |
+
# Predict
|
| 785 |
results = []
|
| 786 |
for i in range(0, len(all_tokens), INFERENCE_CHUNK_SIZE):
|
| 787 |
chunk = all_tokens[i : i + INFERENCE_CHUNK_SIZE]
|
| 788 |
+
|
| 789 |
+
# Prepare Inputs
|
| 790 |
w_ids = torch.LongTensor([[word_vocab[t['text']] for t in chunk]]).to(DEVICE)
|
| 791 |
+
|
| 792 |
+
c_ids_list = []
|
| 793 |
+
for t in chunk:
|
| 794 |
+
chars = [char_vocab[c] for c in t['text'][:MAX_CHAR_LEN]]
|
| 795 |
+
chars += [0] * (MAX_CHAR_LEN - len(chars))
|
| 796 |
+
c_ids_list.append(chars)
|
| 797 |
+
c_ids = torch.LongTensor([c_ids_list]).to(DEVICE)
|
| 798 |
+
|
| 799 |
bboxes = torch.FloatTensor([[[t['x0']/1000.0, t['y0']/1000.0, t['x1']/1000.0, t['y1']/1000.0] for t in chunk]]).to(DEVICE)
|
| 800 |
s_feats = torch.FloatTensor([[t['spatial_features'] for t in chunk]]).to(DEVICE)
|
| 801 |
c_feats = torch.FloatTensor([[t['context_features'] for t in chunk]]).to(DEVICE)
|
|
|
|
| 806 |
for t, p in zip(chunk, preds):
|
| 807 |
results.append({"word": t['text'], "label": IDX2LABEL[p]})
|
| 808 |
|
| 809 |
+
return "✅ Success", results
|
| 810 |
+
|
| 811 |
except Exception as e:
|
| 812 |
+
import traceback
|
| 813 |
return f"❌ Error: {str(e)}", []
|
| 814 |
|
| 815 |
+
# --- 8. UI ---
|
|
|
|
| 816 |
demo = gr.Interface(
|
| 817 |
fn=gradio_inference,
|
| 818 |
+
inputs=gr.File(label="Upload PDF"),
|
| 819 |
outputs=[gr.Textbox(label="Status"), gr.JSON(label="Predictions")],
|
| 820 |
+
title="MCQ Enhanced Tagger"
|
| 821 |
)
|
| 822 |
|
| 823 |
if __name__ == "__main__":
|