heerjtdev committed on
Commit
6a20046
·
verified ·
1 Parent(s): 6ca2c19

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +104 -241
train.py CHANGED
@@ -1,244 +1,107 @@
1
- import json
2
- import argparse
 
 
 
3
  import os
4
- import random
5
- import torch
6
- import torch.nn as nn
7
- from torch.utils.data import Dataset, DataLoader, random_split
8
- from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
9
- from TorchCRF import CRF
10
- from torch.optim import AdamW
11
- from tqdm import tqdm
12
- from sklearn.metrics import precision_recall_fscore_support
13
-
14
# --- Configuration ---
# NOTE(review): MAX_BBOX_DIMENSION is defined but the code below hard-codes
# the literal 1000 when clamping — keep the two in sync if this changes.
MAX_BBOX_DIMENSION = 1000  # LayoutLM-style normalized coordinate space upper bound
MAX_SHIFT = 30  # max per-axis jitter applied by the augmentation step
AUGMENTATION_FACTOR = 1  # jittered copies appended per original sample
BASE_MODEL_ID = "heerjtdev/MLP_LayoutLM"  # HF hub id for backbone + tokenizer
19
-
20
- # -------------------------
21
- # Step 1: Preprocessing
22
- # -------------------------
23
def preprocess_labelstudio(input_path, output_path):
    """Convert a Label Studio export into token/BIO-label/bbox records.

    Reads the raw JSON at *input_path*, clamps every bounding box into the
    0-1000 coordinate space, projects each span annotation onto word-level
    BIO tags (first matching word window wins), and writes the processed
    list to *output_path*, which is returned.
    """
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    def clamp(coord):
        # Force a single coordinate into the [0, 1000] range.
        return max(0, min(coord, 1000))

    processed = []
    print(f"🔄 Starting preprocessing of {len(data)} documents...")

    for item in data:
        words = item["data"]["original_words"]
        labels = ["O"] * len(words)

        clamped_bboxes = []
        for x_min, y_min, x_max, y_max in item["data"]["original_bboxes"]:
            cx_min, cy_min = clamp(x_min), clamp(y_min)
            cx_max, cy_max = clamp(x_max), clamp(y_max)
            # A degenerate box (min edge past max edge) collapses onto the max edge.
            clamped_bboxes.append(
                [min(cx_min, cx_max), min(cy_min, cy_max), cx_max, cy_max]
            )

        for ann in item.get("annotations", []):
            for res in ann["result"]:
                value = res.get("value", {})
                if "labels" not in value:
                    continue
                tag = value["labels"][0]
                span_tokens = value["text"].split()
                window = len(span_tokens)
                for start in range(len(words) - window + 1):
                    if words[start:start + window] == span_tokens:
                        labels[start] = f"B-{tag}"
                        for offset in range(1, window):
                            labels[start + offset] = f"I-{tag}"
                        break

        processed.append({"tokens": words, "labels": labels, "bboxes": clamped_bboxes})

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(processed, f, indent=2, ensure_ascii=False)
    return output_path
65
-
66
- # -------------------------
67
- # Step 1.5: Augmentation
68
- # -------------------------
69
def translate_bbox(bbox, shift_x, shift_y):
    """Shift a [x_min, y_min, x_max, y_max] box by (shift_x, shift_y).

    Every translated coordinate is clamped back into the [0, 1000]
    normalized page space; a new 4-element list is returned.
    """
    deltas = (shift_x, shift_y, shift_x, shift_y)
    return [max(0, min(coord + delta, 1000)) for coord, delta in zip(bbox, deltas)]
76
-
77
def augment_sample(sample):
    """Return a copy of *sample* whose boxes are jittered by one random offset.

    A single (dx, dy) pair is drawn uniformly from [-MAX_SHIFT, MAX_SHIFT]
    and applied to every bounding box; tokens and labels are shared with
    the original (shallow copy), matching the training-data schema.
    """
    dx = random.randint(-MAX_SHIFT, MAX_SHIFT)
    dy = random.randint(-MAX_SHIFT, MAX_SHIFT)
    augmented = dict(sample)
    augmented["bboxes"] = [translate_bbox(box, dx, dy) for box in sample["bboxes"]]
    return augmented
83
-
84
def augment_and_save_dataset(input_json_path, output_json_path):
    """Expand a dataset with AUGMENTATION_FACTOR jittered copies per sample.

    Each original sample is kept and immediately followed by its augmented
    variants; the combined list is written to *output_json_path*, which is
    returned for chaining.
    """
    with open(input_json_path, "r", encoding="utf-8") as f:
        originals = json.load(f)

    augmented = []
    for sample in originals:
        augmented.append(sample)
        augmented.extend(augment_sample(sample) for _ in range(AUGMENTATION_FACTOR))

    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(augmented, f, indent=2, ensure_ascii=False)
    return output_json_path
95
-
96
- # -------------------------
97
- # Step 2: Dataset Class
98
- # -------------------------
99
class LayoutDataset(Dataset):
    """Token-classification dataset backed by a preprocessed JSON file.

    Each record holds parallel "tokens", "bboxes" and "labels" lists.
    Items are tokenized lazily; word-level BIO labels are spread across
    the resulting sub-word positions, with special tokens mapped to "O".
    """

    def __init__(self, json_path, tokenizer, label2id, max_len=512):
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        encodings = self.tokenizer(
            record["tokens"],
            boxes=record["bboxes"],
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        outside = self.label2id["O"]
        word_labels = record["labels"]
        # Special tokens (word_id None) fall back to "O"; sub-word pieces
        # inherit their word's tag; unknown tags also degrade to "O".
        label_ids = [
            outside if word_id is None
            else self.label2id.get(word_labels[word_id], outside)
            for word_id in encodings.word_ids(batch_index=0)
        ]
        encodings["labels"] = torch.tensor(label_ids)
        # Drop the batch dimension added by return_tensors="pt".
        return {key: tensor.squeeze(0) for key, tensor in encodings.items()}
123
-
124
- # -------------------------
125
- # Step 3: Model Architecture (Non-Linear Head)
126
- # -------------------------
127
-
128
class LayoutLMv3CRF(nn.Module):
    """LayoutLMv3 backbone + non-linear MLP head + CRF for token tagging.

    forward() returns a scalar loss (negative mean log-likelihood) when
    `labels` is supplied, otherwise the Viterbi-decoded tag sequences.
    """

    def __init__(self, num_labels):
        # num_labels: size of the BIO tag set (emission width / CRF states).
        super().__init__()
        # Initializing from scratch (Base weights only)
        print(f"🔄 Initializing backbone from {BASE_MODEL_ID}...")
        self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)

        hidden_size = self.layoutlm.config.hidden_size

        # NON-LINEAR MLP HEAD
        # Replacing the simple Linear layer with a deeper architecture
        # NOTE(review): checkpoint state_dict keys depend on this exact
        # Sequential layout — do not reorder layers without migrating keys.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(), # Non-linear activation
            nn.LayerNorm(hidden_size), # Stability for training from scratch
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_labels)
        )

        self.crf = CRF(num_labels)

    def forward(self, input_ids, bbox, attention_mask, labels=None):
        """Compute the CRF training loss, or decode the best tag paths.

        input_ids / bbox / attention_mask: batched encoder inputs; the
        attention mask doubles as the CRF validity mask.
        labels: gold tag ids; when given the method returns a scalar loss.
        """
        outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state

        # Pass through the new non-linear head
        emissions = self.classifier(sequence_output)

        if labels is not None:
            # TorchCRF's __call__ yields log-likelihood (presumably one value
            # per sequence — .mean() reduces it); negate for a minimizable loss.
            log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
            return -log_likelihood.mean()
        else:
            # Inference path: list of best tag-id paths per masked sequence.
            return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
161
-
162
- # -------------------------
163
- # Step 4: Training + Evaluation
164
- # -------------------------
165
def train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimization pass over *dataloader*; return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for batch in tqdm(dataloader, desc="Training"):
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        labels = batch.pop("labels")  # passed separately so model computes loss
        optimizer.zero_grad()
        loss = model(**batch, labels=labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(dataloader)
177
-
178
def evaluate(model, dataloader, device, id2label):
    """Decode the validation set and return micro (precision, recall, f1).

    NOTE(review): *id2label* is accepted but unused, and micro-averaging
    includes the majority "O" class — both kept as-is to preserve behavior.
    """
    model.eval()
    flat_preds, flat_labels = [], []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = {name: tensor.to(device) for name, tensor in batch.items()}
            gold = batch.pop("labels").cpu().numpy()
            decoded = model(**batch)  # Viterbi paths (labels absent)
            masks = batch["attention_mask"].cpu().numpy()
            for path, row, mask in zip(decoded, gold, masks):
                # Keep only positions covered by the attention mask, then
                # align the decoded path to the same length.
                kept = row[mask == 1].tolist()
                flat_labels.extend(kept)
                flat_preds.extend(path[:len(kept)])
    precision, recall, f1, _ = precision_recall_fscore_support(
        flat_labels, flat_preds, average="micro", zero_division=0
    )
    return precision, recall, f1
193
-
194
- # -------------------------
195
- # Step 5: Main Execution
196
- # -------------------------
197
def main(args):
    """End-to-end pipeline: preprocess, augment, train, and checkpoint."""
    labels = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION", "B-ANSWER", "I-ANSWER", "B-SECTION_HEADING", "I-SECTION_HEADING", "B-PASSAGE", "I-PASSAGE"]
    label2id = {label: index for index, label in enumerate(labels)}
    id2label = {index: label for label, index in label2id.items()}

    temp_dir = "temp_intermediate_files"
    os.makedirs(temp_dir, exist_ok=True)

    # 1. Preprocess & Augment
    bio_json = os.path.join(temp_dir, "data_bio.json")
    preprocess_labelstudio(args.input, bio_json)
    final_data_path = augment_and_save_dataset(bio_json, os.path.join(temp_dir, "data_aug.json"))

    # 2. Setup Data (80/20 train/validation split)
    tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
    dataset = LayoutDataset(final_data_path, tokenizer, label2id, max_len=args.max_len)
    val_size = int(0.2 * len(dataset))
    train_dataset, val_dataset = random_split(dataset, [len(dataset) - val_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)

    # 3. Model + optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LayoutLMv3CRF(num_labels=len(labels)).to(device)
    optimizer = AdamW(model.parameters(), lr=args.lr)

    # 4. Training loop
    for epoch in range(args.epochs):
        loss = train_one_epoch(model, train_loader, optimizer, device)
        p, r, f1 = evaluate(model, val_loader, device, id2label)
        print(f"Epoch {epoch+1} | Loss: {loss:.4f} | F1: {f1:.3f}")

    # Persist final weights (state_dict only; no optimizer state).
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(model.state_dict(), "checkpoints/layoutlmv3_nonlinear_scratch.pth")
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
if __name__ == "__main__":
    # CLI entry point. NOTE(review): --mode is accepted but never consulted
    # by main(); kept for backward compatibility with existing invocations.
    cli = argparse.ArgumentParser()
    cli.add_argument("--mode", type=str, default="train")
    cli.add_argument("--input", type=str, required=True)
    cli.add_argument("--batch_size", type=int, default=4)
    cli.add_argument("--epochs", type=int, default=10)  # longer run for scratch training
    cli.add_argument("--lr", type=float, default=2e-5)
    cli.add_argument("--max_len", type=int, default=512)
    main(cli.parse_args())
 
1
+ import gradio as gr
2
+ import fitz # PyMuPDF
3
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
+ from langchain_community.vectorstores import FAISS
5
+ from langchain_huggingface import HuggingFaceEmbeddings
6
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
# --- Backend Logic ---

class VectorSystem:
    """Holds the embedding model and a FAISS index built from one uploaded PDF.

    process_pdf() (re)builds the index; retrieve_evidence() queries it.
    All methods return user-facing status/markdown strings for the UI.
    """

    def __init__(self):
        self.vector_store = None  # FAISS index; None until a PDF is processed
        # Use a lightweight CPU-friendly model
        self.embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    def process_pdf(self, file_obj):
        """Extracts text from PDF and builds the Vector Index"""
        if file_obj is None:
            return "No file uploaded."

        try:
            # 1. Extract Text.
            # BUG FIX: the fitz Document was never closed, leaking an OS-level
            # file handle per upload; the context manager guarantees cleanup.
            with fitz.open(file_obj.name) as doc:
                text = "".join(page.get_text() for page in doc)

            # 2. Split Text into Chunks
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=800,
                chunk_overlap=150,
                separators=["\n\n", "\n", ".", " ", ""]
            )
            chunks = text_splitter.split_text(text)

            if not chunks:
                return "Could not extract text. Is the PDF scanned images?"

            # 3. Build Vector Index (FAISS)
            self.vector_store = FAISS.from_texts(chunks, self.embeddings)

            return f"✅ Success! Indexed {len(chunks)} text chunks from the PDF."

        except Exception as e:
            # Best-effort reporting: errors are surfaced to the UI as text.
            return f"Error processing PDF: {str(e)}"

    def retrieve_evidence(self, question, student_answer):
        """Finds relevant text chunks based on the Question"""
        if not self.vector_store:
            return "⚠️ Please upload and process a PDF first."

        if not question:
            return "⚠️ Please enter a Question."

        # We search primarily using the Question to find the 'Ground Truth' in the text.
        # You could also concatenate (Question + " " + Answer) for a broader search.
        docs = self.vector_store.similarity_search(question, k=3)

        # Format the output
        output_text = "### 🔍 Relevant Context Found:\n\n"
        for i, doc in enumerate(docs):
            output_text += f"**Chunk {i+1}:**\n> {doc.page_content}\n\n"

        output_text += "---\n*These are the most relevant segments to grade the answer against.*"
        return output_text
66
+
67
# Initialize System (single shared instance backing all UI callbacks)
system = VectorSystem()

# --- Gradio UI ---

with gr.Blocks(title="EduGenius Context Retriever") as demo:
    gr.Markdown("# 🎓 EduGenius: PDF Context Retriever")
    gr.Markdown("Upload a chapter, ask a question, and see exactly which part of the text proves the answer right or wrong.")

    with gr.Row():
        with gr.Column(scale=1):
            # Step 1: Upload
            pdf_input = gr.File(label="1. Upload PDF Chapter", file_types=[".pdf"])
            upload_btn = gr.Button("Process PDF", variant="primary")
            # BUG FIX: "interactve" was a typo — gradio rejects unknown
            # keyword arguments, so the app crashed while building the UI.
            upload_status = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=2):
            # Step 2: Query
            question_input = gr.Textbox(label="2. Question", placeholder="e.g., What causes the chemical reaction?")
            answer_input = gr.Textbox(label="Student Answer (Optional Context)", placeholder="e.g., The heat causes it...")
            search_btn = gr.Button("Find Relevant Evidence", variant="secondary")

    # Output
    evidence_output = gr.Markdown(label="Relevant Text Chunks")

    # Event Handlers
    upload_btn.click(
        fn=system.process_pdf,
        inputs=[pdf_input],
        outputs=[upload_status]
    )

    search_btn.click(
        fn=system.retrieve_evidence,
        inputs=[question_input, answer_input],
        outputs=[evidence_output]
    )

# Launch
if __name__ == "__main__":
    demo.launch()