Update app.py
Browse files
app.py
CHANGED
|
@@ -553,16 +553,16 @@
|
|
| 553 |
|
| 554 |
# demo.launch(show_error=True)
|
| 555 |
|
| 556 |
-
|
| 557 |
import os
|
| 558 |
import json
|
| 559 |
import pickle
|
|
|
|
|
|
|
| 560 |
from typing import List, Dict, Any, Tuple
|
| 561 |
from collections import Counter
|
| 562 |
import torch
|
| 563 |
import torch.nn as nn
|
| 564 |
import torch.nn.functional as F
|
| 565 |
-
import re
|
| 566 |
from tqdm import tqdm
|
| 567 |
import gradio as gr
|
| 568 |
import fitz # PyMuPDF
|
|
@@ -570,7 +570,6 @@ import sys
|
|
| 570 |
from types import ModuleType
|
| 571 |
|
| 572 |
# --- 1. CRITICAL: MOCK THE TRAINING MODULE ---
|
| 573 |
-
# We create a fake module to satisfy pickle/torch.load
|
| 574 |
train_mod = ModuleType("train_model")
|
| 575 |
sys.modules["train_model"] = train_mod
|
| 576 |
|
|
@@ -581,7 +580,6 @@ except ImportError:
|
|
| 581 |
try:
|
| 582 |
from TorchCRF import CRF
|
| 583 |
except ImportError:
|
| 584 |
-
# Fallback if libraries are missing (prevents crash, but model won't load)
|
| 585 |
class CRF(nn.Module):
|
| 586 |
def __init__(self, *args, **kwargs): super().__init__()
|
| 587 |
|
|
@@ -608,7 +606,7 @@ LABELS = [
|
|
| 608 |
]
|
| 609 |
IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
|
| 610 |
|
| 611 |
-
# --- 4. CLASSES
|
| 612 |
|
| 613 |
class Vocab:
|
| 614 |
def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
|
|
@@ -696,26 +694,23 @@ class MCQTagger(nn.Module):
|
|
| 696 |
emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
|
| 697 |
return self.crf.viterbi_decode(emissions, mask=mask)
|
| 698 |
|
| 699 |
-
#
|
| 700 |
-
# This tells pickle: "When you look for 'train_model.Vocab', look here instead."
|
| 701 |
train_mod.Vocab = Vocab
|
| 702 |
train_mod.MCQTagger = MCQTagger
|
| 703 |
train_mod.CharCNNEncoder = CharCNNEncoder
|
| 704 |
train_mod.SpatialAttention = SpatialAttention
|
| 705 |
|
| 706 |
-
# ---
|
|
|
|
| 707 |
def extract_spatial_features(tokens, idx):
|
| 708 |
curr = tokens[idx]
|
| 709 |
f = []
|
| 710 |
-
# Vertical gaps
|
| 711 |
if idx < len(tokens)-1: f.append(min((tokens[idx+1]['y0'] - curr['y1'])/100.0, 1.0))
|
| 712 |
else: f.append(0.0)
|
| 713 |
if idx > 0: f.append(min((curr['y0'] - tokens[idx-1]['y1'])/100.0, 1.0))
|
| 714 |
else: f.append(0.0)
|
| 715 |
-
# Positioning
|
| 716 |
f.extend([curr['x0']/1000.0, (curr['x1']-curr['x0'])/1000.0, (curr['y1']-curr['y0'])/1000.0])
|
| 717 |
f.extend([(curr['x0']+curr['x1'])/2000.0, (curr['y0']+curr['y1'])/2000.0, curr['x0']/1000.0])
|
| 718 |
-
# Ratio & Alignment
|
| 719 |
f.append(min(((curr['x1']-curr['x0'])/max((curr['y1']-curr['y0']),1.0))/10.0, 1.0))
|
| 720 |
if idx > 0: f.append(float(abs(curr['x0'] - tokens[idx-1]['x0']) < 5))
|
| 721 |
else: f.append(0.0)
|
|
@@ -733,14 +728,11 @@ def extract_context_features(tokens, idx, window=3):
|
|
| 733 |
res = check_p(i)
|
| 734 |
prev_res = [max(prev_res[j], res[j]) for j in range(3)]
|
| 735 |
f.extend(prev_res)
|
| 736 |
-
|
| 737 |
next_res = [0.0, 0.0, 0.0]
|
| 738 |
for i in range(idx+1, min(len(tokens), idx+window+1)):
|
| 739 |
res = check_p(i)
|
| 740 |
next_res = [max(next_res[j], res[j]) for j in range(3)]
|
| 741 |
f.extend(next_res)
|
| 742 |
-
|
| 743 |
-
# Distances
|
| 744 |
dq, dopt = 1.0, 1.0
|
| 745 |
for i in range(idx+1, min(len(tokens), idx+window+1)):
|
| 746 |
t = tokens[i]['text'].lower().strip()
|
|
@@ -749,17 +741,172 @@ def extract_context_features(tokens, idx, window=3):
|
|
| 749 |
f.extend([dq, dopt])
|
| 750 |
return f
|
| 751 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 752 |
# --- 7. INFERENCE WRAPPER ---
|
|
|
|
| 753 |
def gradio_inference(pdf_file):
|
| 754 |
if not os.path.exists(MODEL_FILE) or not os.path.exists(VOCAB_FILE):
|
| 755 |
return "❌ Missing model/vocab files.", []
|
| 756 |
|
| 757 |
try:
|
| 758 |
-
# Load
|
| 759 |
with open(VOCAB_FILE, "rb") as f:
|
| 760 |
word_vocab, char_vocab = pickle.load(f)
|
| 761 |
|
| 762 |
-
# Load Model
|
| 763 |
model = MCQTagger(len(word_vocab), len(char_vocab), len(LABELS)).to(DEVICE)
|
| 764 |
state_dict = torch.load(MODEL_FILE, map_location=DEVICE)
|
| 765 |
model.load_state_dict(state_dict if isinstance(state_dict, dict) else state_dict.state_dict())
|
|
@@ -776,19 +923,17 @@ def gradio_inference(pdf_file):
|
|
| 776 |
|
| 777 |
if not all_tokens: return "❌ No text found.", []
|
| 778 |
|
| 779 |
-
#
|
| 780 |
for i in range(len(all_tokens)):
|
| 781 |
all_tokens[i]['spatial_features'] = extract_spatial_features(all_tokens, i)
|
| 782 |
all_tokens[i]['context_features'] = extract_context_features(all_tokens, i)
|
| 783 |
|
| 784 |
# Predict
|
| 785 |
-
|
| 786 |
for i in range(0, len(all_tokens), INFERENCE_CHUNK_SIZE):
|
| 787 |
chunk = all_tokens[i : i + INFERENCE_CHUNK_SIZE]
|
| 788 |
|
| 789 |
-
# Prepare Inputs
|
| 790 |
w_ids = torch.LongTensor([[word_vocab[t['text']] for t in chunk]]).to(DEVICE)
|
| 791 |
-
|
| 792 |
c_ids_list = []
|
| 793 |
for t in chunk:
|
| 794 |
chars = [char_vocab[c] for c in t['text'][:MAX_CHAR_LEN]]
|
|
@@ -804,19 +949,24 @@ def gradio_inference(pdf_file):
|
|
| 804 |
with torch.no_grad():
|
| 805 |
preds = model(w_ids, c_ids, bboxes, s_feats, c_feats, mask)[0]
|
| 806 |
for t, p in zip(chunk, preds):
|
| 807 |
-
|
|
|
|
| 808 |
|
| 809 |
-
|
|
|
|
|
|
|
|
|
|
| 810 |
|
| 811 |
except Exception as e:
|
| 812 |
import traceback
|
|
|
|
| 813 |
return f"❌ Error: {str(e)}", []
|
| 814 |
|
| 815 |
# --- 8. UI ---
|
| 816 |
demo = gr.Interface(
|
| 817 |
fn=gradio_inference,
|
| 818 |
inputs=gr.File(label="Upload PDF"),
|
| 819 |
-
outputs=[gr.Textbox(label="Status"), gr.JSON(label="
|
| 820 |
title="MCQ Enhanced Tagger"
|
| 821 |
)
|
| 822 |
|
|
|
|
| 553 |
|
| 554 |
# demo.launch(show_error=True)
|
| 555 |
|
|
|
|
| 556 |
import os
|
| 557 |
import json
|
| 558 |
import pickle
|
| 559 |
+
import time
|
| 560 |
+
import re
|
| 561 |
from typing import List, Dict, Any, Tuple
|
| 562 |
from collections import Counter
|
| 563 |
import torch
|
| 564 |
import torch.nn as nn
|
| 565 |
import torch.nn.functional as F
|
|
|
|
| 566 |
from tqdm import tqdm
|
| 567 |
import gradio as gr
|
| 568 |
import fitz # PyMuPDF
|
|
|
|
| 570 |
from types import ModuleType
|
| 571 |
|
| 572 |
# --- 1. CRITICAL: MOCK THE TRAINING MODULE ---
|
|
|
|
| 573 |
train_mod = ModuleType("train_model")
|
| 574 |
sys.modules["train_model"] = train_mod
|
| 575 |
|
|
|
|
| 580 |
try:
|
| 581 |
from TorchCRF import CRF
|
| 582 |
except ImportError:
|
|
|
|
| 583 |
class CRF(nn.Module):
|
| 584 |
def __init__(self, *args, **kwargs): super().__init__()
|
| 585 |
|
|
|
|
| 606 |
]
|
| 607 |
IDX2LABEL = {i: l for i, l in enumerate(LABELS)}
|
| 608 |
|
| 609 |
+
# --- 4. CLASSES ---
|
| 610 |
|
| 611 |
class Vocab:
|
| 612 |
def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
|
|
|
|
| 694 |
emissions = self.ff(torch.cat([lstm_out, attn_out], dim=-1))
|
| 695 |
return self.crf.viterbi_decode(emissions, mask=mask)
|
| 696 |
|
| 697 |
+
# Link classes to the fake module
|
|
|
|
| 698 |
train_mod.Vocab = Vocab
|
| 699 |
train_mod.MCQTagger = MCQTagger
|
| 700 |
train_mod.CharCNNEncoder = CharCNNEncoder
|
| 701 |
train_mod.SpatialAttention = SpatialAttention
|
| 702 |
|
| 703 |
+
# --- 5. FEATURE HELPERS ---
|
| 704 |
+
|
| 705 |
def extract_spatial_features(tokens, idx):
|
| 706 |
curr = tokens[idx]
|
| 707 |
f = []
|
|
|
|
| 708 |
if idx < len(tokens)-1: f.append(min((tokens[idx+1]['y0'] - curr['y1'])/100.0, 1.0))
|
| 709 |
else: f.append(0.0)
|
| 710 |
if idx > 0: f.append(min((curr['y0'] - tokens[idx-1]['y1'])/100.0, 1.0))
|
| 711 |
else: f.append(0.0)
|
|
|
|
| 712 |
f.extend([curr['x0']/1000.0, (curr['x1']-curr['x0'])/1000.0, (curr['y1']-curr['y0'])/1000.0])
|
| 713 |
f.extend([(curr['x0']+curr['x1'])/2000.0, (curr['y0']+curr['y1'])/2000.0, curr['x0']/1000.0])
|
|
|
|
| 714 |
f.append(min(((curr['x1']-curr['x0'])/max((curr['y1']-curr['y0']),1.0))/10.0, 1.0))
|
| 715 |
if idx > 0: f.append(float(abs(curr['x0'] - tokens[idx-1]['x0']) < 5))
|
| 716 |
else: f.append(0.0)
|
|
|
|
| 728 |
res = check_p(i)
|
| 729 |
prev_res = [max(prev_res[j], res[j]) for j in range(3)]
|
| 730 |
f.extend(prev_res)
|
|
|
|
| 731 |
next_res = [0.0, 0.0, 0.0]
|
| 732 |
for i in range(idx+1, min(len(tokens), idx+window+1)):
|
| 733 |
res = check_p(i)
|
| 734 |
next_res = [max(next_res[j], res[j]) for j in range(3)]
|
| 735 |
f.extend(next_res)
|
|
|
|
|
|
|
| 736 |
dq, dopt = 1.0, 1.0
|
| 737 |
for i in range(idx+1, min(len(tokens), idx+window+1)):
|
| 738 |
t = tokens[i]['text'].lower().strip()
|
|
|
|
| 741 |
f.extend([dq, dopt])
|
| 742 |
return f
|
| 743 |
|
| 744 |
+
# --- 6. STRUCTURING LOGIC (Injected) ---
|
| 745 |
+
|
| 746 |
+
def convert_predictions_to_structured(predictions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert a flat BIO-tagged token stream into structured MCQ items.

    Each input element looks like ``{'word': str, 'predicted_label': str}``,
    where the label is ``'O'`` or ``'B-'``/``'I-'`` plus an entity type
    (QUESTION, OPTION, ANSWER, PASSAGE, SECTION_HEADING).

    Returns a list of dicts. The first entry may be a
    ``{'type': 'METADATA', ...}`` item holding any header/passage text seen
    before the first question; every other entry has the keys ``question``,
    ``options``, ``answer``, ``passage`` and ``text`` (the raw token span),
    plus an optional ``new_passage`` when a passage begins immediately after
    an option (a mid-document passage switch).
    """
    print("--- STARTING BIO TO STRUCTURED JSON DECODING ---")
    start_time = time.time()

    structured_data: List[Dict[str, Any]] = []
    current_item = None             # MCQ item currently being assembled
    current_option_key = None       # key of the option that I-OPTION words extend
    current_passage_buffer = []     # passage words pending attachment to an item
    current_text_buffer = []        # raw words of the current item's span
    first_question_started = False
    last_entity_type = None
    just_finished_i_option = False  # last word extended an option via I-OPTION
    is_in_new_passage = False       # currently inside a post-option passage

    def finalize_passage_to_item(item, passage_buffer):
        # Flush buffered passage words into item['passage'], whitespace-normalized.
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            if item.get('passage'):
                item['passage'] += ' ' + passage_text
            else:
                item['passage'] = passage_text
            passage_buffer.clear()

    for pred in predictions:
        word = pred['word']
        label = pred['predicted_label']
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
        current_text_buffer.append(word)

        previous_entity_type = last_entity_type
        is_passage_label = (entity_type == 'PASSAGE')

        # Before the first question, only passage words are collected; everything
        # else is header text that ends up in the METADATA item.
        if not first_question_started:
            if label != 'B-QUESTION' and not is_passage_label:
                just_finished_i_option = False
                is_in_new_passage = False
                continue
            if is_passage_label:
                current_passage_buffer.append(word)
                last_entity_type = 'PASSAGE'
                just_finished_i_option = False
                is_in_new_passage = False
                continue

        if label == 'B-QUESTION':
            if not first_question_started:
                # Emit any pre-question header/passage text as a METADATA item.
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    metadata_item = {'type': 'METADATA', 'passage': ''}
                    finalize_passage_to_item(metadata_item, current_passage_buffer)
                    if header_text:
                        metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]

            if current_item is not None:
                # Close the previous question before starting a new one.
                finalize_passage_to_item(current_item, current_passage_buffer)
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                current_text_buffer = [word]

            current_item = {
                'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''
            }
            current_option_key = None
            last_entity_type = 'QUESTION'
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if current_item is not None:
            if is_in_new_passage:
                # Words keep extending the post-option passage until a
                # non-passage B-/I- label appears; that boundary word is still
                # absorbed into new_passage before the flag clears.
                if 'new_passage' not in current_item:
                    current_item['new_passage'] = word
                else:
                    current_item['new_passage'] += f' {word}'

                if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'):
                    is_in_new_passage = False

                if label.startswith(('B-', 'I-')):
                    last_entity_type = entity_type
                continue

            is_in_new_passage = False

            if label.startswith('B-'):
                if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
                    finalize_passage_to_item(current_item, current_passage_buffer)
                    current_passage_buffer = []

                last_entity_type = entity_type

                if entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        # Passage starting right after an option -> 'new_passage'.
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        current_passage_buffer.append(word)

                elif entity_type == 'OPTION':
                    current_option_key = word
                    current_item['options'][current_option_key] = word
                    just_finished_i_option = False

                elif entity_type == 'ANSWER':
                    current_item['answer'] = word
                    current_option_key = None
                    just_finished_i_option = False

                elif entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                    just_finished_i_option = False

            elif label.startswith('I-'):
                if entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                elif entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        if not current_passage_buffer:
                            last_entity_type = 'PASSAGE'
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION' and current_option_key is not None:
                    current_item['options'][current_option_key] += f' {word}'
                elif entity_type == 'ANSWER':
                    current_item['answer'] += f' {word}'

                just_finished_i_option = (entity_type == 'OPTION')
            # 'O' labels carry no entity: the word stays only in the text buffer.

    if current_item is not None:
        finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(current_item)

    # Whitespace cleanup. NOTE: a METADATA item may lack 'text' (passage-only
    # header), so guard the lookup instead of indexing unconditionally — the
    # unguarded version raised KeyError on that path.
    for item in structured_data:
        if 'text' in item:
            item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
        if 'new_passage' in item:
            item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()

    print(f"✅ Decoding Complete. Total time: {time.time() - start_time:.2f}s")
    return structured_data
|
| 898 |
+
|
| 899 |
# --- 7. INFERENCE WRAPPER ---
|
| 900 |
+
|
| 901 |
def gradio_inference(pdf_file):
|
| 902 |
if not os.path.exists(MODEL_FILE) or not os.path.exists(VOCAB_FILE):
|
| 903 |
return "❌ Missing model/vocab files.", []
|
| 904 |
|
| 905 |
try:
|
| 906 |
+
# Load Resources
|
| 907 |
with open(VOCAB_FILE, "rb") as f:
|
| 908 |
word_vocab, char_vocab = pickle.load(f)
|
| 909 |
|
|
|
|
| 910 |
model = MCQTagger(len(word_vocab), len(char_vocab), len(LABELS)).to(DEVICE)
|
| 911 |
state_dict = torch.load(MODEL_FILE, map_location=DEVICE)
|
| 912 |
model.load_state_dict(state_dict if isinstance(state_dict, dict) else state_dict.state_dict())
|
|
|
|
| 923 |
|
| 924 |
if not all_tokens: return "❌ No text found.", []
|
| 925 |
|
| 926 |
+
# Features
|
| 927 |
for i in range(len(all_tokens)):
|
| 928 |
all_tokens[i]['spatial_features'] = extract_spatial_features(all_tokens, i)
|
| 929 |
all_tokens[i]['context_features'] = extract_context_features(all_tokens, i)
|
| 930 |
|
| 931 |
# Predict
|
| 932 |
+
raw_predictions = []
|
| 933 |
for i in range(0, len(all_tokens), INFERENCE_CHUNK_SIZE):
|
| 934 |
chunk = all_tokens[i : i + INFERENCE_CHUNK_SIZE]
|
| 935 |
|
|
|
|
| 936 |
w_ids = torch.LongTensor([[word_vocab[t['text']] for t in chunk]]).to(DEVICE)
|
|
|
|
| 937 |
c_ids_list = []
|
| 938 |
for t in chunk:
|
| 939 |
chars = [char_vocab[c] for c in t['text'][:MAX_CHAR_LEN]]
|
|
|
|
| 949 |
with torch.no_grad():
|
| 950 |
preds = model(w_ids, c_ids, bboxes, s_feats, c_feats, mask)[0]
|
| 951 |
for t, p in zip(chunk, preds):
|
| 952 |
+
# NOTE: Structuring logic uses 'predicted_label' key
|
| 953 |
+
raw_predictions.append({"word": t['text'], "predicted_label": IDX2LABEL[p]})
|
| 954 |
|
| 955 |
+
# Structure Output
|
| 956 |
+
structured_json = convert_predictions_to_structured(raw_predictions)
|
| 957 |
+
|
| 958 |
+
return "✅ Processing Complete", structured_json
|
| 959 |
|
| 960 |
except Exception as e:
|
| 961 |
import traceback
|
| 962 |
+
traceback.print_exc()
|
| 963 |
return f"❌ Error: {str(e)}", []
|
| 964 |
|
| 965 |
# --- 8. UI ---
|
| 966 |
demo = gr.Interface(
|
| 967 |
fn=gradio_inference,
|
| 968 |
inputs=gr.File(label="Upload PDF"),
|
| 969 |
+
outputs=[gr.Textbox(label="Status"), gr.JSON(label="Structured Output")],
|
| 970 |
title="MCQ Enhanced Tagger"
|
| 971 |
)
|
| 972 |
|