heerjtdev committed on
Commit
0f65208
·
verified ·
1 Parent(s): b386cf4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -19
app.py CHANGED
@@ -554,7 +554,6 @@
554
  # demo.launch(show_error=True)
555
 
556
 
557
-
558
  import os
559
  import json
560
  import pickle
@@ -571,8 +570,7 @@ import sys
571
  from types import ModuleType
572
 
573
  # --- 1. CRITICAL: MOCK THE TRAINING MODULE ---
574
- # Your model was saved with references to a 'train_model' module.
575
- # We create a fake module and inject our local classes into it so torch.load works.
576
  train_mod = ModuleType("train_model")
577
  sys.modules["train_model"] = train_mod
578
 
@@ -580,9 +578,14 @@ sys.modules["train_model"] = train_mod
580
  try:
581
  from torch_crf import CRF
582
  except ImportError:
583
- from TorchCRF import CRF
 
 
 
 
 
584
 
585
- # --- 3. CONFIG (Matching Training Script) ---
586
  MODEL_FILE = "model_enhanced.pt"
587
  VOCAB_FILE = "vocabs_enhanced.pkl"
588
  DEVICE = torch.device("cpu")
@@ -605,7 +608,19 @@ LABELS = [
605
  ]
606
  IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
607
 
608
- # --- 4. MODEL ARCHITECTURE (Exact Match to Training Script) ---
 
 
 
 
 
 
 
 
 
 
 
 
609
 
610
  class CharCNNEncoder(nn.Module):
611
  def __init__(self, char_vocab_size, char_emb_dim, out_dim, kernel_sizes=(2, 3, 4, 5)):
@@ -681,12 +696,14 @@ class MCQTagger(nn.Module):
681
  emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
682
  return self.crf.viterbi_decode(emissions, mask=mask)
683
 
684
- # Link classes to the fake module
 
 
685
  train_mod.MCQTagger = MCQTagger
686
  train_mod.CharCNNEncoder = CharCNNEncoder
 
687
 
688
- # --- 5. FEATURE EXTRACTION HELPERS (From Training Script) ---
689
-
690
  def extract_spatial_features(tokens, idx):
691
  curr = tokens[idx]
692
  f = []
@@ -732,22 +749,23 @@ def extract_context_features(tokens, idx, window=3):
732
  f.extend([dq, dopt])
733
  return f
734
 
735
- # --- 6. INFERENCE PIPELINE ---
736
-
737
  def gradio_inference(pdf_file):
738
  if not os.path.exists(MODEL_FILE) or not os.path.exists(VOCAB_FILE):
739
  return "❌ Missing model/vocab files.", []
740
 
741
  try:
 
742
  with open(VOCAB_FILE, "rb") as f:
743
  word_vocab, char_vocab = pickle.load(f)
744
 
 
745
  model = MCQTagger(len(word_vocab), len(char_vocab), len(LABELS)).to(DEVICE)
746
  state_dict = torch.load(MODEL_FILE, map_location=DEVICE)
747
  model.load_state_dict(state_dict if isinstance(state_dict, dict) else state_dict.state_dict())
748
  model.eval()
749
 
750
- # Extract tokens
751
  doc = fitz.open(pdf_file.name)
752
  all_tokens = []
753
  for page in doc:
@@ -756,16 +774,28 @@ def gradio_inference(pdf_file):
756
  all_tokens.append({'text': text, 'x0': x0, 'y0': y0, 'x1': x1, 'y1': y1})
757
  doc.close()
758
 
759
- # Generate features
 
 
760
  for i in range(len(all_tokens)):
761
  all_tokens[i]['spatial_features'] = extract_spatial_features(all_tokens, i)
762
  all_tokens[i]['context_features'] = extract_context_features(all_tokens, i)
763
 
 
764
  results = []
765
  for i in range(0, len(all_tokens), INFERENCE_CHUNK_SIZE):
766
  chunk = all_tokens[i : i + INFERENCE_CHUNK_SIZE]
 
 
767
  w_ids = torch.LongTensor([[word_vocab[t['text']] for t in chunk]]).to(DEVICE)
768
- c_ids = torch.LongTensor([[([char_vocab[c] for c in t['text'][:MAX_CHAR_LEN]] + [0]*MAX_CHAR_LEN)[:MAX_CHAR_LEN] for t in chunk]]).to(DEVICE)
 
 
 
 
 
 
 
769
  bboxes = torch.FloatTensor([[[t['x0']/1000.0, t['y0']/1000.0, t['x1']/1000.0, t['y1']/1000.0] for t in chunk]]).to(DEVICE)
770
  s_feats = torch.FloatTensor([[t['spatial_features'] for t in chunk]]).to(DEVICE)
771
  c_feats = torch.FloatTensor([[t['context_features'] for t in chunk]]).to(DEVICE)
@@ -776,17 +806,18 @@ def gradio_inference(pdf_file):
776
  for t, p in zip(chunk, preds):
777
  results.append({"word": t['text'], "label": IDX2LABEL[p]})
778
 
779
- return "✅ Processing Complete", results
 
780
  except Exception as e:
 
781
  return f"❌ Error: {str(e)}", []
782
 
783
- # --- 7. INTERFACE ---
784
-
785
  demo = gr.Interface(
786
  fn=gradio_inference,
787
- inputs=gr.File(label="Upload MCQ PDF"),
788
  outputs=[gr.Textbox(label="Status"), gr.JSON(label="Predictions")],
789
- title="Enhanced MCQ Tagger (Spatial Attention + BiLSTM-CRF)"
790
  )
791
 
792
  if __name__ == "__main__":
 
554
  # demo.launch(show_error=True)
555
 
556
 
 
557
  import os
558
  import json
559
  import pickle
 
570
  from types import ModuleType
571
 
572
  # --- 1. CRITICAL: MOCK THE TRAINING MODULE ---
573
+ # We create a fake module to satisfy pickle/torch.load
 
574
  train_mod = ModuleType("train_model")
575
  sys.modules["train_model"] = train_mod
576
 
 
578
  try:
579
  from torch_crf import CRF
580
  except ImportError:
581
+ try:
582
+ from TorchCRF import CRF
583
+ except ImportError:
584
+ # Fallback if libraries are missing (prevents crash, but model won't load)
585
+ class CRF(nn.Module):
586
+ def __init__(self, *args, **kwargs): super().__init__()
587
 
588
+ # --- 3. CONFIG ---
589
  MODEL_FILE = "model_enhanced.pt"
590
  VOCAB_FILE = "vocabs_enhanced.pkl"
591
  DEVICE = torch.device("cpu")
 
608
  ]
609
  IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
610
 
611
+ # --- 4. CLASSES (Re-defined to match training) ---
612
+
613
+ class Vocab:
614
+ def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
615
+ self.min_freq = min_freq
616
+ self.unk_token = unk_token
617
+ self.pad_token = pad_token
618
+ self.freq = Counter()
619
+ self.itos = []
620
+ self.stoi = {}
621
+
622
+ def __len__(self): return len(self.itos)
623
+ def __getitem__(self, token): return self.stoi.get(token, self.stoi.get(self.unk_token, 0))
624
 
625
  class CharCNNEncoder(nn.Module):
626
  def __init__(self, char_vocab_size, char_emb_dim, out_dim, kernel_sizes=(2, 3, 4, 5)):
 
696
  emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
697
  return self.crf.viterbi_decode(emissions, mask=mask)
698
 
699
+ # --- 5. CRITICAL FIX: LINK CLASSES TO FAKE MODULE ---
700
+ # This tells pickle: "When you look for 'train_model.Vocab', look here instead."
701
+ train_mod.Vocab = Vocab
702
  train_mod.MCQTagger = MCQTagger
703
  train_mod.CharCNNEncoder = CharCNNEncoder
704
+ train_mod.SpatialAttention = SpatialAttention
705
 
706
+ # --- 6. FEATURE EXTRACTORS ---
 
707
  def extract_spatial_features(tokens, idx):
708
  curr = tokens[idx]
709
  f = []
 
749
  f.extend([dq, dopt])
750
  return f
751
 
752
+ # --- 7. INFERENCE WRAPPER ---
 
753
  def gradio_inference(pdf_file):
754
  if not os.path.exists(MODEL_FILE) or not os.path.exists(VOCAB_FILE):
755
  return "❌ Missing model/vocab files.", []
756
 
757
  try:
758
+ # Load Vocab
759
  with open(VOCAB_FILE, "rb") as f:
760
  word_vocab, char_vocab = pickle.load(f)
761
 
762
+ # Load Model
763
  model = MCQTagger(len(word_vocab), len(char_vocab), len(LABELS)).to(DEVICE)
764
  state_dict = torch.load(MODEL_FILE, map_location=DEVICE)
765
  model.load_state_dict(state_dict if isinstance(state_dict, dict) else state_dict.state_dict())
766
  model.eval()
767
 
768
+ # Parse PDF
769
  doc = fitz.open(pdf_file.name)
770
  all_tokens = []
771
  for page in doc:
 
774
  all_tokens.append({'text': text, 'x0': x0, 'y0': y0, 'x1': x1, 'y1': y1})
775
  doc.close()
776
 
777
+ if not all_tokens: return "❌ No text found.", []
778
+
779
+ # Feature Extraction
780
  for i in range(len(all_tokens)):
781
  all_tokens[i]['spatial_features'] = extract_spatial_features(all_tokens, i)
782
  all_tokens[i]['context_features'] = extract_context_features(all_tokens, i)
783
 
784
+ # Predict
785
  results = []
786
  for i in range(0, len(all_tokens), INFERENCE_CHUNK_SIZE):
787
  chunk = all_tokens[i : i + INFERENCE_CHUNK_SIZE]
788
+
789
+ # Prepare Inputs
790
  w_ids = torch.LongTensor([[word_vocab[t['text']] for t in chunk]]).to(DEVICE)
791
+
792
+ c_ids_list = []
793
+ for t in chunk:
794
+ chars = [char_vocab[c] for c in t['text'][:MAX_CHAR_LEN]]
795
+ chars += [0] * (MAX_CHAR_LEN - len(chars))
796
+ c_ids_list.append(chars)
797
+ c_ids = torch.LongTensor([c_ids_list]).to(DEVICE)
798
+
799
  bboxes = torch.FloatTensor([[[t['x0']/1000.0, t['y0']/1000.0, t['x1']/1000.0, t['y1']/1000.0] for t in chunk]]).to(DEVICE)
800
  s_feats = torch.FloatTensor([[t['spatial_features'] for t in chunk]]).to(DEVICE)
801
  c_feats = torch.FloatTensor([[t['context_features'] for t in chunk]]).to(DEVICE)
 
806
  for t, p in zip(chunk, preds):
807
  results.append({"word": t['text'], "label": IDX2LABEL[p]})
808
 
809
+ return "✅ Success", results
810
+
811
  except Exception as e:
812
+ import traceback
813
  return f"❌ Error: {str(e)}", []
814
 
815
+ # --- 8. UI ---
 
816
  demo = gr.Interface(
817
  fn=gradio_inference,
818
+ inputs=gr.File(label="Upload PDF"),
819
  outputs=[gr.Textbox(label="Status"), gr.JSON(label="Predictions")],
820
+ title="MCQ Enhanced Tagger"
821
  )
822
 
823
  if __name__ == "__main__":