heerjtdev commited on
Commit
0ff9ac0
·
verified ·
1 Parent(s): 0f65208

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -23
app.py CHANGED
@@ -553,16 +553,16 @@
553
 
554
  # demo.launch(show_error=True)
555
 
556
-
557
  import os
558
  import json
559
  import pickle
 
 
560
  from typing import List, Dict, Any, Tuple
561
  from collections import Counter
562
  import torch
563
  import torch.nn as nn
564
  import torch.nn.functional as F
565
- import re
566
  from tqdm import tqdm
567
  import gradio as gr
568
  import fitz # PyMuPDF
@@ -570,7 +570,6 @@ import sys
570
  from types import ModuleType
571
 
572
# --- 1. CRITICAL: MOCK THE TRAINING MODULE ---
# Register a fake "train_model" module so that pickle/torch.load can resolve
# classes that were pickled as "train_model.<ClassName>" at training time.
# (The real classes are attached to this module further down in the file.)
train_mod = ModuleType("train_model")
sys.modules["train_model"] = train_mod
576
 
@@ -581,7 +580,6 @@ except ImportError:
581
try:
    # Preferred: real CRF implementation (the checkpoint was trained with it).
    from TorchCRF import CRF
except ImportError:
    # Fallback stub so the module still imports without TorchCRF installed.
    # It defines no forward/decode logic, so a real model won't work with it.
    # NOTE(review): per the diff hunk context this inner try/except sits inside
    # an outer `except ImportError:` block not visible here — confirm nesting.
    class CRF(nn.Module):
        def __init__(self, *args, **kwargs): super().__init__()
587
 
@@ -608,7 +606,7 @@ LABELS = [
608
  ]
609
  IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
610
 
611
- # --- 4. CLASSES (Re-defined to match training) ---
612
 
613
  class Vocab:
614
  def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
@@ -696,26 +694,23 @@ class MCQTagger(nn.Module):
696
  emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
697
  return self.crf.viterbi_decode(emissions, mask=mask)
698
 
699
# --- 5. CRITICAL FIX: LINK CLASSES TO FAKE MODULE ---
# This tells pickle: "When you look for 'train_model.Vocab', look here instead."
train_mod.Vocab = Vocab
train_mod.MCQTagger = MCQTagger
train_mod.CharCNNEncoder = CharCNNEncoder
train_mod.SpatialAttention = SpatialAttention
705
 
706
- # --- 6. FEATURE EXTRACTORS ---
 
707
  def extract_spatial_features(tokens, idx):
708
  curr = tokens[idx]
709
  f = []
710
- # Vertical gaps
711
  if idx < len(tokens)-1: f.append(min((tokens[idx+1]['y0'] - curr['y1'])/100.0, 1.0))
712
  else: f.append(0.0)
713
  if idx > 0: f.append(min((curr['y0'] - tokens[idx-1]['y1'])/100.0, 1.0))
714
  else: f.append(0.0)
715
- # Positioning
716
  f.extend([curr['x0']/1000.0, (curr['x1']-curr['x0'])/1000.0, (curr['y1']-curr['y0'])/1000.0])
717
  f.extend([(curr['x0']+curr['x1'])/2000.0, (curr['y0']+curr['y1'])/2000.0, curr['x0']/1000.0])
718
- # Ratio & Alignment
719
  f.append(min(((curr['x1']-curr['x0'])/max((curr['y1']-curr['y0']),1.0))/10.0, 1.0))
720
  if idx > 0: f.append(float(abs(curr['x0'] - tokens[idx-1]['x0']) < 5))
721
  else: f.append(0.0)
@@ -733,14 +728,11 @@ def extract_context_features(tokens, idx, window=3):
733
  res = check_p(i)
734
  prev_res = [max(prev_res[j], res[j]) for j in range(3)]
735
  f.extend(prev_res)
736
-
737
  next_res = [0.0, 0.0, 0.0]
738
  for i in range(idx+1, min(len(tokens), idx+window+1)):
739
  res = check_p(i)
740
  next_res = [max(next_res[j], res[j]) for j in range(3)]
741
  f.extend(next_res)
742
-
743
- # Distances
744
  dq, dopt = 1.0, 1.0
745
  for i in range(idx+1, min(len(tokens), idx+window+1)):
746
  t = tokens[i]['text'].lower().strip()
@@ -749,17 +741,172 @@ def extract_context_features(tokens, idx, window=3):
749
  f.extend([dq, dopt])
750
  return f
751
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
752
  # --- 7. INFERENCE WRAPPER ---
 
753
  def gradio_inference(pdf_file):
754
  if not os.path.exists(MODEL_FILE) or not os.path.exists(VOCAB_FILE):
755
  return "❌ Missing model/vocab files.", []
756
 
757
  try:
758
- # Load Vocab
759
  with open(VOCAB_FILE, "rb") as f:
760
  word_vocab, char_vocab = pickle.load(f)
761
 
762
- # Load Model
763
  model = MCQTagger(len(word_vocab), len(char_vocab), len(LABELS)).to(DEVICE)
764
  state_dict = torch.load(MODEL_FILE, map_location=DEVICE)
765
  model.load_state_dict(state_dict if isinstance(state_dict, dict) else state_dict.state_dict())
@@ -776,19 +923,17 @@ def gradio_inference(pdf_file):
776
 
777
  if not all_tokens: return "❌ No text found.", []
778
 
779
- # Feature Extraction
780
  for i in range(len(all_tokens)):
781
  all_tokens[i]['spatial_features'] = extract_spatial_features(all_tokens, i)
782
  all_tokens[i]['context_features'] = extract_context_features(all_tokens, i)
783
 
784
  # Predict
785
- results = []
786
  for i in range(0, len(all_tokens), INFERENCE_CHUNK_SIZE):
787
  chunk = all_tokens[i : i + INFERENCE_CHUNK_SIZE]
788
 
789
- # Prepare Inputs
790
  w_ids = torch.LongTensor([[word_vocab[t['text']] for t in chunk]]).to(DEVICE)
791
-
792
  c_ids_list = []
793
  for t in chunk:
794
  chars = [char_vocab[c] for c in t['text'][:MAX_CHAR_LEN]]
@@ -804,19 +949,24 @@ def gradio_inference(pdf_file):
804
  with torch.no_grad():
805
  preds = model(w_ids, c_ids, bboxes, s_feats, c_feats, mask)[0]
806
  for t, p in zip(chunk, preds):
807
- results.append({"word": t['text'], "label": IDX2LABEL[p]})
 
808
 
809
- return "✅ Success", results
 
 
 
810
 
811
  except Exception as e:
812
  import traceback
 
813
  return f"❌ Error: {str(e)}", []
814
 
815
  # --- 8. UI ---
816
  demo = gr.Interface(
817
  fn=gradio_inference,
818
  inputs=gr.File(label="Upload PDF"),
819
- outputs=[gr.Textbox(label="Status"), gr.JSON(label="Predictions")],
820
  title="MCQ Enhanced Tagger"
821
  )
822
 
 
553
 
554
  # demo.launch(show_error=True)
555
 
 
556
  import os
557
  import json
558
  import pickle
559
+ import time
560
+ import re
561
  from typing import List, Dict, Any, Tuple
562
  from collections import Counter
563
  import torch
564
  import torch.nn as nn
565
  import torch.nn.functional as F
 
566
  from tqdm import tqdm
567
  import gradio as gr
568
  import fitz # PyMuPDF
 
570
  from types import ModuleType
571
 
572
  # --- 1. CRITICAL: MOCK THE TRAINING MODULE ---
 
573
  train_mod = ModuleType("train_model")
574
  sys.modules["train_model"] = train_mod
575
 
 
580
  try:
581
  from TorchCRF import CRF
582
  except ImportError:
 
583
  class CRF(nn.Module):
584
  def __init__(self, *args, **kwargs): super().__init__()
585
 
 
606
  ]
607
  IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
608
 
609
+ # --- 4. CLASSES ---
610
 
611
  class Vocab:
612
  def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
 
694
  emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
695
  return self.crf.viterbi_decode(emissions, mask=mask)
696
 
697
+ # Link classes to the fake module
 
698
  train_mod.Vocab = Vocab
699
  train_mod.MCQTagger = MCQTagger
700
  train_mod.CharCNNEncoder = CharCNNEncoder
701
  train_mod.SpatialAttention = SpatialAttention
702
 
703
+ # --- 5. FEATURE HELPERS ---
704
+
705
  def extract_spatial_features(tokens, idx):
706
  curr = tokens[idx]
707
  f = []
 
708
  if idx < len(tokens)-1: f.append(min((tokens[idx+1]['y0'] - curr['y1'])/100.0, 1.0))
709
  else: f.append(0.0)
710
  if idx > 0: f.append(min((curr['y0'] - tokens[idx-1]['y1'])/100.0, 1.0))
711
  else: f.append(0.0)
 
712
  f.extend([curr['x0']/1000.0, (curr['x1']-curr['x0'])/1000.0, (curr['y1']-curr['y0'])/1000.0])
713
  f.extend([(curr['x0']+curr['x1'])/2000.0, (curr['y0']+curr['y1'])/2000.0, curr['x0']/1000.0])
 
714
  f.append(min(((curr['x1']-curr['x0'])/max((curr['y1']-curr['y0']),1.0))/10.0, 1.0))
715
  if idx > 0: f.append(float(abs(curr['x0'] - tokens[idx-1]['x0']) < 5))
716
  else: f.append(0.0)
 
728
  res = check_p(i)
729
  prev_res = [max(prev_res[j], res[j]) for j in range(3)]
730
  f.extend(prev_res)
 
731
  next_res = [0.0, 0.0, 0.0]
732
  for i in range(idx+1, min(len(tokens), idx+window+1)):
733
  res = check_p(i)
734
  next_res = [max(next_res[j], res[j]) for j in range(3)]
735
  f.extend(next_res)
 
 
736
  dq, dopt = 1.0, 1.0
737
  for i in range(idx+1, min(len(tokens), idx+window+1)):
738
  t = tokens[i]['text'].lower().strip()
 
741
  f.extend([dq, dopt])
742
  return f
743
 
744
# --- 6. STRUCTURING LOGIC (Injected) ---

def convert_predictions_to_structured(predictions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Converts a flat list of predictions [{'word':..., 'predicted_label':...}]
    into structured JSON, implementing the specific logic provided.

    Single pass over the BIO-tagged word stream.  The output list contains:
      * optionally, one leading {'type': 'METADATA', ...} item holding any
        header text / passage seen before the first B-QUESTION, and
      * one item per question: {'question', 'options', 'answer', 'passage',
        'text'}, plus an optional 'new_passage' key for a passage that starts
        immediately after an option run.

    NOTE(review): this function's indentation was reconstructed from a diff
    view with flattened whitespace; the nesting of the `is_in_new_passage`
    branch in particular should be confirmed against the original file.
    """
    print("--- STARTING BIO TO STRUCTURED JSON DECODING ---")
    start_time = time.time()

    total_words = len(predictions)      # NOTE(review): currently unused.
    structured_data = []                # finished items, in document order
    current_item = None                 # question item being assembled
    current_option_key = None           # key of the option being extended
    current_passage_buffer = []         # words of the passage being collected
    current_text_buffer = []            # raw words since the current item started
    first_question_started = False      # becomes True at the first B-QUESTION
    last_entity_type = None             # entity type of the previous tagged word
    just_finished_i_option = False      # True right after an I-OPTION word
    is_in_new_passage = False           # collecting into item['new_passage']

    def finalize_passage_to_item(item, passage_buffer):
        # Flush buffered passage words into item['passage'] (appending if one
        # is already present), collapsing whitespace runs; empties the buffer
        # in place so the caller's list is reset too.
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            if item.get('passage'):
                item['passage'] += ' ' + passage_text
            else:
                item['passage'] = passage_text
            passage_buffer.clear()

    for idx, item in enumerate(predictions):
        word = item['word']
        label = item['predicted_label']
        # 'B-QUESTION' -> 'QUESTION'; a plain 'O' label yields None.
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
        current_text_buffer.append(word)

        previous_entity_type = last_entity_type
        is_passage_label = (entity_type == 'PASSAGE')

        # Before the first question: only passage words are collected;
        # everything else just accumulates in the raw text buffer.
        if not first_question_started:
            if label != 'B-QUESTION' and not is_passage_label:
                just_finished_i_option = False
                is_in_new_passage = False
                continue
            if is_passage_label:
                current_passage_buffer.append(word)
                last_entity_type = 'PASSAGE'
                just_finished_i_option = False
                is_in_new_passage = False
                continue

        # A new question starts: emit header metadata (first time only) or
        # finalize the previous question item, then open a fresh one.
        if label == 'B-QUESTION':
            if not first_question_started:
                # Everything seen so far (minus this word) is header text.
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    metadata_item = {'type': 'METADATA', 'passage': ''}
                    finalize_passage_to_item(metadata_item, current_passage_buffer)
                    if header_text: metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]

            if current_item is not None:
                finalize_passage_to_item(current_item, current_passage_buffer)
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                current_text_buffer = [word]

            current_item = {
                'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''
            }
            current_option_key = None
            last_entity_type = 'QUESTION'
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if current_item is not None:
            # Post-option passage ('new_passage') collection mode: the word is
            # appended first, then the mode may be ended by a non-passage tag.
            if is_in_new_passage:
                if 'new_passage' not in current_item:
                    current_item['new_passage'] = word
                else:
                    current_item['new_passage'] += f' {word}'

                # Any non-passage B-/I- entity terminates the new_passage run.
                if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'):
                    is_in_new_passage = False

                if label.startswith(('B-', 'I-')):
                    last_entity_type = entity_type
                continue

            is_in_new_passage = False

            if label.startswith('B-'):
                # Starting a non-passage entity flushes any buffered passage.
                # SECTION_HEADING only triggers this flush; its words are not
                # stored in a dedicated field.
                if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
                    finalize_passage_to_item(current_item, current_passage_buffer)
                    current_passage_buffer = []

                last_entity_type = entity_type

                if entity_type == 'PASSAGE':
                    # A passage right after an I-OPTION run becomes a
                    # 'new_passage' attached to the current question.
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        current_passage_buffer.append(word)

                elif entity_type == 'OPTION':
                    # The first word of an option doubles as its dict key.
                    current_option_key = word
                    current_item['options'][current_option_key] = word
                    just_finished_i_option = False

                elif entity_type == 'ANSWER':
                    current_item['answer'] = word
                    current_option_key = None
                    just_finished_i_option = False

                elif entity_type == 'QUESTION':
                    # Effectively unreachable: 'B-QUESTION' is intercepted and
                    # `continue`d above; kept as written (defensive append).
                    current_item['question'] += f' {word}'
                    just_finished_i_option = False

            elif label.startswith('I-'):
                if entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                elif entity_type == 'PASSAGE':
                    # I-PASSAGE directly after an option run also opens a
                    # new_passage rather than extending the regular passage.
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        if not current_passage_buffer: last_entity_type = 'PASSAGE'
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION' and current_option_key is not None:
                    current_item['options'][current_option_key] += f' {word}'
                    just_finished_i_option = True  # re-set by the line below
                elif entity_type == 'ANSWER':
                    current_item['answer'] += f' {word}'

                # Unconditional recompute: True for any I-OPTION word, even
                # one dropped above because no option key was open.
                just_finished_i_option = (entity_type == 'OPTION')

            elif label == 'O':
                # 'O' words remain only in the raw text buffer.
                pass

    # Flush the final question item.
    if current_item is not None:
        finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(current_item)

    # Normalise whitespace in the raw-text fields.
    for item in structured_data:
        item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
        if 'new_passage' in item:
            item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()

    print(f"✅ Decoding Complete. Total time: {time.time() - start_time:.2f}s")
    return structured_data
899
  # --- 7. INFERENCE WRAPPER ---
900
+
901
  def gradio_inference(pdf_file):
902
  if not os.path.exists(MODEL_FILE) or not os.path.exists(VOCAB_FILE):
903
  return "❌ Missing model/vocab files.", []
904
 
905
  try:
906
+ # Load Resources
907
  with open(VOCAB_FILE, "rb") as f:
908
  word_vocab, char_vocab = pickle.load(f)
909
 
 
910
  model = MCQTagger(len(word_vocab), len(char_vocab), len(LABELS)).to(DEVICE)
911
  state_dict = torch.load(MODEL_FILE, map_location=DEVICE)
912
  model.load_state_dict(state_dict if isinstance(state_dict, dict) else state_dict.state_dict())
 
923
 
924
  if not all_tokens: return "❌ No text found.", []
925
 
926
+ # Features
927
  for i in range(len(all_tokens)):
928
  all_tokens[i]['spatial_features'] = extract_spatial_features(all_tokens, i)
929
  all_tokens[i]['context_features'] = extract_context_features(all_tokens, i)
930
 
931
  # Predict
932
+ raw_predictions = []
933
  for i in range(0, len(all_tokens), INFERENCE_CHUNK_SIZE):
934
  chunk = all_tokens[i : i + INFERENCE_CHUNK_SIZE]
935
 
 
936
  w_ids = torch.LongTensor([[word_vocab[t['text']] for t in chunk]]).to(DEVICE)
 
937
  c_ids_list = []
938
  for t in chunk:
939
  chars = [char_vocab[c] for c in t['text'][:MAX_CHAR_LEN]]
 
949
  with torch.no_grad():
950
  preds = model(w_ids, c_ids, bboxes, s_feats, c_feats, mask)[0]
951
  for t, p in zip(chunk, preds):
952
+ # NOTE: Structuring logic uses 'predicted_label' key
953
+ raw_predictions.append({"word": t['text'], "predicted_label": IDX2LABEL[p]})
954
 
955
+ # Structure Output
956
+ structured_json = convert_predictions_to_structured(raw_predictions)
957
+
958
+ return "✅ Processing Complete", structured_json
959
 
960
  except Exception as e:
961
  import traceback
962
+ traceback.print_exc()
963
  return f"❌ Error: {str(e)}", []
964
 
965
  # --- 8. UI ---
966
  demo = gr.Interface(
967
  fn=gradio_inference,
968
  inputs=gr.File(label="Upload PDF"),
969
+ outputs=[gr.Textbox(label="Status"), gr.JSON(label="Structured Output")],
970
  title="MCQ Enhanced Tagger"
971
  )
972