heerjtdev committed on
Commit
95abb5a
Β·
verified Β·
1 Parent(s): 72e0c96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +403 -313
app.py CHANGED
@@ -1,325 +1,415 @@
1
- # import gradio as gr
2
- # print("GRADIO VERSION:", gr.__version__)
3
- # import json
4
- # import os
5
- # import tempfile
6
- # from pathlib import Path
7
-
8
- # # NOTE: You must ensure that 'working_yolo_pipeline.py' exists
9
- # # and defines the following items correctly:
10
- # from working_yolo_pipeline import run_document_pipeline, DEFAULT_LAYOUTLMV3_MODEL_PATH, WEIGHTS_PATH
11
- # # Since I don't have this file, I am assuming the imports are correct.
12
-
13
- # # Define placeholders for assumed constants if the pipeline file isn't present
14
- # # You should replace these with your actual definitions if they are missing
15
- # try:
16
- # from working_yolo_pipeline import run_document_pipeline, DEFAULT_LAYOUTLMV3_MODEL_PATH, WEIGHTS_PATH
17
- # except ImportError:
18
- # print("Warning: 'working_yolo_pipeline.py' not found. Using dummy paths.")
19
- # def run_document_pipeline(*args):
20
- # return {"error": "Placeholder pipeline function called."}
21
- # DEFAULT_LAYOUTLMV3_MODEL_PATH = "./models/layoutlmv3_model"
22
- # WEIGHTS_PATH = "./weights/yolo_weights.pt"
23
-
24
-
25
- # def process_pdf(pdf_file, layoutlmv3_model_path=None):
26
- # """
27
- # Wrapper function for Gradio interface.
28
-
29
- # Args:
30
- # pdf_file: Gradio UploadButton file object
31
- # layoutlmv3_model_path: Optional custom model path
32
-
33
- # Returns:
34
- # Tuple of (JSON string, download file path)
35
- # """
36
- # if pdf_file is None:
37
- # return "❌ Error: No PDF file uploaded.", None
38
-
39
- # # Use default model path if not provided
40
- # if not layoutlmv3_model_path:
41
- # layoutlmv3_model_path = DEFAULT_LAYOUTLMV3_MODEL_PATH
42
-
43
- # # Verify model and weights exist
44
- # if not os.path.exists(layoutlmv3_model_path):
45
- # return f"❌ Error: LayoutLMv3 model not found at {layoutlmv3_model_path}", None
46
-
47
- # if not os.path.exists(WEIGHTS_PATH):
48
- # return f"❌ Error: YOLO weights not found at {WEIGHTS_PATH}", None
49
-
50
- # try:
51
- # # Get the uploaded PDF path
52
- # pdf_path = pdf_file.name
53
-
54
- # # Run the pipeline
55
- # result = run_document_pipeline(pdf_path, layoutlmv3_model_path, 'label_studio_import.json')
56
-
57
- # if result is None:
58
- # return "❌ Error: Pipeline failed to process the PDF. Check console for details.", None
59
-
60
- # # Create a temporary file for download
61
- # output_filename = f"{Path(pdf_path).stem}_analysis.json"
62
- # temp_output = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json', prefix='analysis_')
63
-
64
- # # Dump results to the temporary file
65
- # with open(temp_output.name, 'w', encoding='utf-8') as f:
66
- # json.dump(result, f, indent=2, ensure_ascii=False)
67
-
68
- # # Format JSON for display
69
- # json_display = json.dumps(result, indent=2, ensure_ascii=False)
70
-
71
- # return json_display, temp_output.name
72
-
73
- # except Exception as e:
74
- # return f"❌ Error during processing: {str(e)}", None
75
-
76
-
77
- # # Create Gradio interface
78
- # # FIX APPLIED: Removed 'theme=gr.themes.Soft()' which caused the TypeError
79
- # with gr.Blocks(title="Document Analysis Pipeline") as demo:
80
- # gr.Markdown("""
81
- # # πŸ“„ Document Analysis Pipeline
82
-
83
- # Upload a PDF document to extract structured data including questions, options, answers, passages, and embedded images.
84
-
85
- # **Pipeline Steps:**
86
- # 1. πŸ” YOLO/OCR Preprocessing (word extraction + figure/equation detection)
87
- # 2. πŸ€– LayoutLMv3 Inference (BIO tagging)
88
- # 3. πŸ“Š Structured JSON Decoding
89
- # 4. πŸ–ΌοΈ Base64 Image Embedding
90
- # """)
91
-
92
- # with gr.Row():
93
- # with gr.Column(scale=1):
94
- # pdf_input = gr.File(
95
- # label="Upload PDF Document",
96
- # file_types=[".pdf"],
97
- # type="filepath"
98
- # )
99
-
100
- # model_path_input = gr.Textbox(
101
- # label="LayoutLMv3 Model Path (optional)",
102
- # placeholder=DEFAULT_LAYOUTLMV3_MODEL_PATH,
103
- # value=DEFAULT_LAYOUTLMV3_MODEL_PATH,
104
- # interactive=True
105
- # )
106
-
107
- # process_btn = gr.Button("πŸš€ Process Document", variant="primary", size="lg")
108
-
109
- # gr.Markdown("""
110
- # ### ℹ️ Notes:
111
- # - Processing may take several minutes depending on PDF size
112
- # - Figures and equations will be extracted and embedded as Base64
113
- # - The output JSON includes structured questions, options, and answers
114
- # """)
115
-
116
- # with gr.Column(scale=2):
117
- # json_output = gr.Code(
118
- # label="Structured JSON Output",
119
- # language="json",
120
- # lines=25
121
- # )
122
-
123
- # download_output = gr.File(
124
- # label="Download Full JSON",
125
- # interactive=False
126
- # )
127
-
128
- # # Status/Examples section
129
- # with gr.Row():
130
- # gr.Markdown("""
131
- # ### πŸ“‹ Output Format
132
- # The pipeline generates JSON with the following structure:
133
- # - **Questions**: Extracted question text
134
- # - **Options**: Multiple choice options (A, B, C, D, etc.)
135
- # - **Answers**: Correct answer(s)
136
- # - **Passages**: Associated reading passages
137
- # - **Images**: Base64-encoded figures and equations (embedded with keys like `figure1`, `equation2`)
138
- # """)
139
-
140
- # # Connect the button to the processing function
141
- # process_btn.click(
142
- # fn=process_pdf,
143
- # inputs=[pdf_input, model_path_input],
144
- # outputs=[json_output, download_output],
145
- # api_name="process_document"
146
- # )
147
-
148
- # # Example section (optional - add example PDFs if available)
149
- # # gr.Examples(
150
- # # examples=[
151
- # # ["examples/sample1.pdf"],
152
- # # ["examples/sample2.pdf"],
153
- # # ],
154
- # # inputs=pdf_input,
155
- # # )
156
-
157
- # # Launch the app
158
- # if __name__ == "__main__":
159
- # demo.launch(
160
- # server_name="0.0.0.0",
161
- # server_port=7860,
162
- # share=False,
163
- # show_error=True
164
- # )
165
-
166
-
167
-
168
-
169
-
170
  import gradio as gr
171
- print("GRADIO VERSION:", gr.__version__)
172
- import json
173
  import os
174
- import tempfile
175
- from pathlib import Path
176
-
177
- # ==============================
178
- # WRITE CUSTOM CSS FOR FONTS
179
- # ==============================
180
-
181
- # CUSTOM_CSS = """
182
- # @font-face {
183
- # font-family: 'NotoSansMath';
184
- # src: url('./NotoSansMath-Regular.ttf') format('truetype');
185
- # font-weight: normal;
186
- # font-style: normal;
187
- # }
188
-
189
- # html, body, * {
190
- # font-family: 'NotoSansMath', sans-serif !important;
191
- # }
192
- # """
193
-
194
- # # Optionally write the CSS file if needed (not required for inline css)
195
- # if not os.path.exists("custom.css"):
196
- # with open("custom.css", "w") as f:
197
- # f.write(CUSTOM_CSS)
198
- # ==============================
199
-
200
- try:
201
- from working_yolo_pipeline import run_document_pipeline, DEFAULT_LAYOUTLMV3_MODEL_PATH, WEIGHTS_PATH
202
- except ImportError:
203
- print("Warning: 'working_yolo_pipeline.py' not found. Using dummy paths.")
204
- def run_document_pipeline(*args):
205
- return {"error": "Placeholder pipeline function called."}
206
- DEFAULT_LAYOUTLMV3_MODEL_PATH = "./models/layoutlmv3_model"
207
- WEIGHTS_PATH = "./weights/yolo_weights.pt"
208
-
209
-
210
- def process_pdf(pdf_file, layoutlmv3_model_path=None):
211
- if pdf_file is None:
212
- return "❌ Error: No PDF file uploaded.", None
213
-
214
- if not layoutlmv3_model_path:
215
- layoutlmv3_model_path = DEFAULT_LAYOUTLMV3_MODEL_PATH
216
-
217
- if not os.path.exists(layoutlmv3_model_path):
218
- return f"❌ Error: LayoutLMv3 model not found at {layoutlmv3_model_path}", None
219
-
220
- if not os.path.exists(WEIGHTS_PATH):
221
- return f"❌ Error: YOLO weights not found at {WEIGHTS_PATH}", None
222
-
223
- try:
224
- pdf_path = pdf_file.name
225
-
226
- result = run_document_pipeline(pdf_path, layoutlmv3_model_path, 'label_studio_import.json')
227
-
228
- if result is None:
229
- return "❌ Error: Pipeline failed to process the PDF. Check console for details.", None
230
-
231
- output_filename = f"{Path(pdf_path).stem}_analysis.json"
232
- temp_output = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json', prefix='analysis_')
233
-
234
- with open(temp_output.name, 'w', encoding='utf-8') as f:
235
- json.dump(result, f, indent=2, ensure_ascii=False)
236
-
237
- json_display = json.dumps(result, indent=2, ensure_ascii=False)
238
-
239
- return json_display, temp_output.name
240
-
241
- except Exception as e:
242
- return f"❌ Error during processing: {str(e)}", None
243
-
244
-
245
- with gr.Blocks(
246
- title="Document Analysis Pipeline"
247
- ) as demo:
248
-
249
-
250
- gr.HTML()
251
-
252
- gr.Markdown("""
253
- # πŸ“„ Document Analysis Pipeline
254
-
255
- Upload a PDF document to extract structured data including questions, options, answers, passages, and embedded images.
256
-
257
- **Pipeline Steps:**
258
- 1. πŸ” YOLO/OCR Preprocessing (word extraction + figure/equation detection)
259
- 2. πŸ€– LayoutLMv3 Inference (BIO tagging)
260
- 3. πŸ“Š Structured JSON Decoding
261
- 4. πŸ–ΌοΈ Base64 Image Embedding
262
- """)
263
-
264
- with gr.Row():
265
- with gr.Column(scale=1):
266
- pdf_input = gr.File(
267
- label="Upload PDF Document",
268
- file_types=[".pdf"],
269
- type="filepath"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  )
271
 
272
- model_path_input = gr.Textbox(
273
- label="LayoutLMv3 Model Path (optional)",
274
- placeholder=DEFAULT_LAYOUTLMV3_MODEL_PATH,
275
- value=DEFAULT_LAYOUTLMV3_MODEL_PATH,
276
- interactive=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  )
278
-
279
- process_btn = gr.Button("πŸš€ Process Document", variant="primary", size="lg")
280
-
281
- gr.Markdown("""
282
- ### ℹ️ Notes:
283
- - Processing may take several minutes depending on PDF size
284
- - Figures and equations will be extracted and embedded as Base64
285
- - The output JSON includes structured questions, options, and answers
286
- """)
287
-
288
- with gr.Column(scale=2):
289
- json_output = gr.Code(
290
- label="Structured JSON Output",
291
- language="json",
292
- lines=25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  )
294
-
295
- download_output = gr.File(
296
- label="Download Full JSON",
297
- interactive=False
 
 
 
 
 
 
 
298
  )
299
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  with gr.Row():
301
- gr.Markdown("""
302
- ### πŸ“‹ Output Format
303
- The pipeline generates JSON with the following structure:
304
- - **Questions**: Extracted question text
305
- - **Options**: Multiple choice options
306
- - **Answers**: Correct answer(s)
307
- - **Passages**: Associated reading passages
308
- - **Images**: Base64-encoded figures and equations
309
- """)
 
 
310
 
311
- process_btn.click(
312
- fn=process_pdf,
313
- inputs=[pdf_input, model_path_input],
314
- outputs=[json_output, download_output],
315
- api_name="process_document"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  )
317
-
318
 
319
  if __name__ == "__main__":
320
- demo.launch(
321
- server_name="0.0.0.0",
322
- server_port=7860,
323
- share=False,
324
- show_error=True
325
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import fitz
3
+ import torch
4
  import os
5
+ import re
6
+ import numpy as np
7
+ from collections import Counter
8
+ import onnxruntime as ort
9
+ from onnxruntime import SessionOptions, GraphOptimizationLevel
10
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
11
+ from langchain_community.vectorstores import FAISS
12
+ from langchain_core.embeddings import Embeddings
13
+ from transformers import AutoTokenizer
14
+ from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForCausalLM
15
+ from huggingface_hub import snapshot_download
16
+ from sentence_transformers import SentenceTransformer # Add this for cross-encoder
17
+
18
+ PROVIDERS = ["CPUExecutionProvider"]
19
+
20
+ # ---------------------------------------------------------
21
+ # 1. EMBEDDINGS (Your existing code - good)
22
+ # ---------------------------------------------------------
23
class OnnxBgeEmbeddings(Embeddings):
    """LangChain-compatible BGE embeddings served by an ONNX runtime on CPU."""

    def __init__(self):
        repo = "Xenova/bge-small-en-v1.5"
        print(f"🔄 Loading Embeddings: {repo}...")
        self.tokenizer = AutoTokenizer.from_pretrained(repo)
        self.model = ORTModelForFeatureExtraction.from_pretrained(
            repo, export=False, provider=PROVIDERS[0]
        )

    def _process_batch(self, batch):
        # Tokenize, run the encoder, take the [CLS] token vector, L2-normalize.
        encoded = self.tokenizer(
            batch, padding=True, truncation=True, max_length=512, return_tensors="pt"
        )
        with torch.no_grad():
            hidden = self.model(**encoded).last_hidden_state
        cls_vectors = torch.nn.functional.normalize(hidden[:, 0], p=2, dim=1)
        return cls_vectors.numpy().tolist()

    def embed_documents(self, texts):
        """Embed a list of document strings; one vector per input text."""
        return self._process_batch(texts)

    def embed_query(self, text):
        """Embed a single query string; returns one vector."""
        return self._process_batch([text])[0]
45
+
46
+ # ---------------------------------------------------------
47
+ # 2. RULE-BASED GRADING ENGINE (NEW - No LLM needed)
48
+ # ---------------------------------------------------------
49
class RuleBasedGrader:
    """
    Rule-based short-answer grader.

    Extracts key concepts from the reference context, measures how well the
    student answer covers them, applies contradiction penalties, and blends in
    an embedding-based semantic similarity.  Works 100% on CPU, deterministic,
    explainable.
    """

    # Function words excluded from concept extraction.
    _STOPWORDS = frozenset({
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
        'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
        'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
        'should', 'may', 'might', 'must', 'shall', 'can', 'need', 'dare',
        'ought', 'used', 'it', 'this', 'that', 'these', 'those', 'i', 'you',
        'he', 'she', 'we', 'they',
    })

    # Negation markers used by the naive contradiction detector.
    _NEGATIONS = ['not', 'no', 'never', 'none', 'nothing', 'nobody', 'neither',
                  'nowhere', 'hardly', 'scarcely', 'barely', "doesn't",
                  "isn't", "wasn't", "shouldn't", "wouldn't", "couldn't",
                  "can't", "don't", "didn't", "hasn't", "haven't", "hadn't",
                  "won't"]

    def __init__(self):
        # Everything is pure-Python heuristics; nothing to load.
        pass

    def extract_key_concepts(self, text, top_k=10):
        """
        Return up to ``top_k`` frequent unigrams/bigrams from ``text``.

        Only terms appearing more than once are kept.  FIX: deduplicate with
        ``dict.fromkeys`` (insertion-ordered) instead of ``set()`` so the
        output — and therefore the grade — is deterministic across runs.
        """
        # Strip punctuation, lowercase, drop stopwords and very short tokens.
        text = re.sub(r'[^\w\s]', ' ', text.lower())
        words = [w for w in text.split()
                 if w not in self._STOPWORDS and len(w) > 2]

        word_freq = Counter(words)
        bigram_freq = Counter(
            f"{words[i]} {words[i + 1]}" for i in range(len(words) - 1)
        )

        # Keep unigrams/bigrams that occur more than once.
        concepts = [w for w, c in word_freq.most_common(top_k) if c > 1]
        concepts += [b for b, c in bigram_freq.most_common(top_k // 2) if c > 1]

        # Order-preserving dedup keeps grading deterministic.
        return list(dict.fromkeys(concepts))[:top_k]

    def check_concept_coverage(self, student_answer, key_concepts):
        """
        Check which ``key_concepts`` appear in ``student_answer``.

        Returns ``(coverage_ratio, found_concepts, missing_concepts)``.
        """
        student_lower = student_answer.lower()
        found_concepts = []
        missing_concepts = []

        for concept in key_concepts:
            if concept in student_lower:
                found_concepts.append(concept)
                continue
            # Crude stemming: accept common suffix variants of every word
            # in the concept (e.g. "run" matches "running"/"runs").
            concept_words = concept.split()
            if all(
                any(variant in student_lower
                    for variant in (cw, cw + 's', cw + 'es', cw + 'ed', cw + 'ing'))
                for cw in concept_words
            ):
                found_concepts.append(concept)
            else:
                missing_concepts.append(concept)

        coverage = len(found_concepts) / len(key_concepts) if key_concepts else 0
        return coverage, found_concepts, missing_concepts

    def detect_contradictions(self, context, student_answer):
        """
        Naive negation-based contradiction detection.

        FIX: negation words are now matched on word boundaries, so e.g. 'no'
        no longer fires inside 'know' and 'not' inside 'nothing'.
        """
        answer_lower = student_answer.lower()
        contradictions = []

        # Only sentences long enough to plausibly state a fact.
        context_sentences = [s.strip() for s in context.split('.')
                             if len(s.strip()) > 10]

        for sent in context_sentences:
            sent_lower = sent.lower()
            for neg in self._NEGATIONS:
                pattern = r'\b' + re.escape(neg) + r'\b'
                if re.search(pattern, sent_lower):
                    # Context negates something: flag if the student appears
                    # to affirm the positive (negation-stripped) version.
                    positive_version = re.sub(pattern, '', sent_lower).strip()
                    if any(word in answer_lower
                           for word in positive_version.split()[:5]):
                        contradictions.append(
                            f"Context says: '{sent}' but student contradicts this"
                        )
                # The opposite direction (positive context, negated answer)
                # would require semantic understanding; intentionally skipped.

        return contradictions

    def calculate_semantic_similarity(self, context, student_answer, embeddings_model):
        """Cosine similarity between context and answer embedding vectors."""
        context_emb = embeddings_model.embed_query(context)
        answer_emb = embeddings_model.embed_query(student_answer)

        denom = np.linalg.norm(context_emb) * np.linalg.norm(answer_emb)
        # Guard against zero vectors (e.g. an empty string) -> similarity 0.
        if denom == 0:
            return 0.0
        return float(np.dot(context_emb, answer_emb) / denom)

    def grade(self, context, question, student_answer, max_marks, embeddings_model):
        """
        Grade ``student_answer`` against ``context`` out of ``max_marks``.

        Score = (60% concept coverage + 40% semantic similarity) * max_marks,
        minus 50% of max_marks per detected contradiction, floored at 0.
        Returns ``(score, markdown_feedback)``.  ``question`` is currently
        unused but kept for interface stability with LLM-backed graders.
        """
        key_concepts = self.extract_key_concepts(context)
        coverage, found, missing = self.check_concept_coverage(student_answer, key_concepts)
        contradictions = self.detect_contradictions(context, student_answer)
        semantic_sim = self.calculate_semantic_similarity(context, student_answer, embeddings_model)

        # Weight: 60% concept coverage, 40% semantic similarity;
        # each contradiction costs half of max_marks.
        base_score = (coverage * 0.6 + semantic_sim * 0.4) * max_marks
        contradiction_penalty = len(contradictions) * (max_marks * 0.5)
        final_score = max(0, base_score - contradiction_penalty)

        feedback = f"""
**Grading Analysis:**

**Key Concepts Found ({len(found)}/{len(key_concepts)}):** {', '.join(found) if found else 'None'}
**Key Concepts Missing:** {', '.join(missing) if missing else 'None'}

**Concept Coverage:** {coverage:.1%}
**Semantic Similarity:** {semantic_sim:.1%}

**Contradictions Detected:** {len(contradictions)}
{chr(10).join(['- ' + c for c in contradictions]) if contradictions else 'None'}

**Calculation:** ({coverage:.1%} × 0.6 + {semantic_sim:.1%} × 0.4) × {max_marks} - {contradiction_penalty:.1f} penalty = **{final_score:.1f}/{max_marks}**
"""

        return final_score, feedback
201
+
202
+ # ---------------------------------------------------------
203
+ # 3. LLM EVALUATOR (Fallback for edge cases)
204
+ # ---------------------------------------------------------
205
class LLMEvaluator:
    """
    ONNX Qwen2.5-0.5B fallback grader for ambiguous cases.

    Downloads the fp16 ONNX export once into a local directory, then runs
    greedy decoding on CPU for nuanced score verification.
    """

    def __init__(self):
        self.repo_id = "onnx-community/Qwen2.5-0.5B-Instruct"
        self.local_dir = "onnx_qwen_local"

        # Download only the files needed for fp16 ONNX inference.
        if not os.path.exists(self.local_dir):
            snapshot_download(
                repo_id=self.repo_id,
                local_dir=self.local_dir,
                allow_patterns=["config.json", "generation_config.json", "tokenizer*", "special_tokens_map.json", "*.jinja", "onnx/model_fp16.onnx*"]
            )

        self.tokenizer = AutoTokenizer.from_pretrained(self.local_dir)

        # NOTE(review): graph optimizations disabled — presumably some fused
        # ops break this fp16 export; confirm before re-enabling.
        sess_options = SessionOptions()
        sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL

        self.model = ORTModelForCausalLM.from_pretrained(
            self.local_dir,
            subfolder="onnx",
            file_name="model_fp16.onnx",
            use_cache=True,
            use_io_binding=False,
            provider=PROVIDERS[0],
            session_options=sess_options
        )

    def evaluate(self, context, question, student_answer, max_marks, rule_based_score):
        """
        Use the LLM only for nuanced cases; short-circuit clear 0/full scores.

        FIX: the short-circuit returns were plain strings, so users saw the
        literal text '{max_marks}' instead of the number — now f-strings.
        """
        # If rule-based grading gave a clear 0 or full marks, skip the LLM.
        if rule_based_score == 0:
            return f"Score: 0/{max_marks}\nFeedback: Answer contains significant errors or contradictions with the reference text."
        if rule_based_score == max_marks:
            return f"Score: {max_marks}/{max_marks}\nFeedback: Excellent answer that fully covers the reference material."

        # Otherwise, use the LLM for nuanced cases.  Keep the prompt small:
        # this is a 0.5B model with a 512-token input budget.
        prompt = f"""Grade this answer based ONLY on the context provided.

Context: {context[:500]}
Question: {question}
Student Answer: {student_answer}

Rules:
1. Give 0 if answer contradicts context or adds outside information
2. Give full marks only if answer matches context exactly
3. Give partial marks for partial matches

Output exactly:
Score: X/{max_marks}
Feedback: One sentence explanation"""

        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

        with torch.no_grad():
            # FIX: dropped temperature=0.1 — it is ignored (and warns) when
            # do_sample=False; decoding is greedy either way.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id
            )

        full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        prompt_text = self.tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True)
        # Keep only the newly generated continuation after the prompt.
        return full_text[len(prompt_text):].strip()
274
+
275
+ # ---------------------------------------------------------
276
+ # 4. MAIN APPLICATION
277
+ # ---------------------------------------------------------
278
class VectorSystem:
    """
    Main application state: the document index plus the two graders.

    Holds a FAISS vector store over text chunks, the ONNX embedding model,
    the rule-based grader, and the (optional) LLM verifier.
    """

    def __init__(self):
        self.vector_store = None       # FAISS index, built lazily by process_content
        self.embeddings = OnnxBgeEmbeddings()
        self.rule_grader = RuleBasedGrader()
        self.llm = LLMEvaluator()
        self.all_chunks = []           # raw chunk texts, addressed by metadata id
        self.total_chunks = 0

    def process_content(self, file_obj, raw_text):
        """
        Index either an uploaded PDF/TXT file or pasted text (not both).

        Returns a human-readable status string for the UI.
        """
        has_file = file_obj is not None
        has_text = raw_text is not None and len(raw_text.strip()) > 0

        if has_file and has_text:
            return "❌ Error: Provide EITHER file OR text, not both."

        if not has_file and not has_text:
            return "⚠️ No content provided."

        try:
            text = ""
            if has_file:
                if file_obj.name.endswith('.pdf'):
                    # FIX: close the PyMuPDF document instead of leaking it.
                    doc = fitz.open(file_obj.name)
                    try:
                        for page in doc:
                            text += page.get_text()
                    finally:
                        doc.close()
                elif file_obj.name.endswith('.txt'):
                    with open(file_obj.name, 'r', encoding='utf-8') as f:
                        text = f.read()
                else:
                    return "❌ Only .pdf and .txt supported."
            else:
                text = raw_text

            # Larger, overlapping chunks give the grader better context.
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200,
                separators=["\n\n", "\n", ". ", " ", ""]
            )
            self.all_chunks = text_splitter.split_text(text)
            self.total_chunks = len(self.all_chunks)

            if not self.all_chunks:
                return "Content empty."

            # Store the chunk index in metadata so retrieval can map back
            # to the original chunk text.
            metadatas = [{"id": i} for i in range(self.total_chunks)]
            self.vector_store = FAISS.from_texts(
                self.all_chunks,
                self.embeddings,
                metadatas=metadatas
            )

            return f"✅ Indexed {self.total_chunks} chunks."
        except Exception as e:
            return f"Error: {str(e)}"

    def process_query(self, question, student_answer, max_marks):
        """
        Retrieve the most relevant chunks and grade the student answer.

        Returns ``(evidence_markdown, grade_markdown)``.
        """
        if not self.vector_store:
            return "⚠️ Upload content first.", ""
        if not question:
            return "⚠️ Enter a question.", ""
        if not student_answer:
            return "⚠️ Enter a student answer.", ""

        # Top-2 chunks give broader context than a single best match.
        results = self.vector_store.similarity_search_with_score(question, k=2)
        context_parts = [self.all_chunks[doc.metadata['id']] for doc, _ in results]
        expanded_context = "\n".join(context_parts)

        # Rule-based grading: fast and deterministic.
        score, feedback = self.rule_grader.grade(
            expanded_context,
            question,
            student_answer,
            max_marks,
            self.embeddings
        )

        # Optional: LLM verification for ambiguous mid-range scores.
        # if 0.2 < (score/max_marks) < 0.8:
        #     llm_feedback = self.llm.evaluate(expanded_context, question, student_answer, max_marks, score)
        #     feedback += f"\n\n**LLM Verification:**\n{llm_feedback}"

        evidence_display = f"### 📚 Context Used:\n{expanded_context[:800]}..."
        grade_display = f"### 📝 Grade: {score:.1f}/{max_marks}\n\n{feedback}"

        return evidence_display, grade_display
372
+
373
# Initialize and launch.
# Loads both models (embeddings + LLM) eagerly at import time, which is what
# Hugging Face Spaces expects; `demo` must stay module-level for Spaces.
system = VectorSystem()

# UI layout: left column uploads/indexes source material; right column takes a
# question + student answer and triggers retrieval and grading.
with gr.Blocks(title="EduGenius AI Grader") as demo:
    gr.Markdown("# ⚡ EduGenius: CPU Optimized RAG")
    gr.Markdown("Hybrid Rule-Based + LLM Grading (ONNX Optimized)")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Source Input")
            pdf_input = gr.File(label="Upload Chapter (PDF/TXT)")
            gr.Markdown("**OR**")
            text_input = gr.Textbox(
                label="Paste Context",
                placeholder="Paste text here...",
                lines=5
            )
            upload_btn = gr.Button("Index Content", variant="primary")
            status_msg = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=2):
            q_input = gr.Textbox(label="Question", scale=2)
            max_marks = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Max Marks")
            a_input = gr.TextArea(label="Student Answer", lines=5)
            run_btn = gr.Button("Retrieve & Grade", variant="secondary")

    with gr.Row():
        evidence_box = gr.Markdown()   # retrieved context shown to the user
        grade_box = gr.Markdown()      # score + feedback markdown

    # Wire the buttons to the VectorSystem methods.
    upload_btn.click(
        system.process_content,
        inputs=[pdf_input, text_input],
        outputs=[status_msg]
    )
    run_btn.click(
        system.process_query,
        inputs=[q_input, a_input, max_marks],
        outputs=[evidence_box, grade_box]
    )


if __name__ == "__main__":
    demo.launch()