heerjtdev commited on
Commit
ccebae5
·
verified ·
1 Parent(s): bbf9bfd

Update app_prince.py

Browse files
Files changed (1) hide show
  1. app_prince.py +169 -942
app_prince.py CHANGED
@@ -1,942 +1,169 @@
1
- import uuid
2
- from collections import Counter
3
- import pickle
4
- from typing import Tuple
5
- from flask import Flask, request, jsonify, Response
6
- from flask_cors import CORS
7
- import os
8
- import json
9
- from datetime import datetime
10
- import random
11
- # from gradio_api import call_layoutlm_api
12
- from gradio_api_prince import call_yolo_api
13
-
14
- """
15
- ===========================================================
16
-
17
-
18
- MODEL OPTIONS
19
-
20
-
21
- ===========================================================
22
- """
23
- app = Flask(__name__)
24
- CORS(app)
25
-
26
- from collections import OrderedDict
27
-
28
- """
29
- ====================================================================
30
-
31
- Helper Functions
32
-
33
- ====================================================================
34
- """
35
-
36
- from vector_db_prince import store_mcqs, fetch_mcqs, fetch_random_mcqs, store_test_session, fetch_test_by_testId, \
37
- test_sessions_by_userId, store_submitted_test, submitted_tests_by_userId, add_single_question, \
38
- update_single_question, delete_single_question, store_mcqs_for_manual_creation, delete_mcq_bank, \
39
- delete_submitted_test_by_id, delete_test_session_by_id, update_test_session, update_question_bank_metadata, \
40
- fetch_submitted_test_by_testId, delete_submitted_test_attempt, update_answer_flag_in_qdrant
41
- from werkzeug.utils import secure_filename
42
-
43
-
44
def format_mcq(mcq):
    """Normalize a raw MCQ dict onto the canonical field names.

    Each canonical field falls back through its known alias keys; if every
    alias is falsy, the value of the last alias checked is returned (which
    mirrors a chained ``or`` expression).
    """
    def _first(keys):
        # Equivalent of `mcq.get(a) or mcq.get(b) or ...`: return the first
        # truthy alias value, else whatever the final alias holds.
        value = None
        for key in keys:
            value = mcq.get(key)
            if value:
                break
        return value

    return {
        "question": _first(("question", "ques", "q")),
        "noise": mcq.get("noise"),
        "image": _first(("image", "img")),
        "options": _first(("options", "opts")),
        "answer": _first(("answer", "ans", "correct")),
    }
52
- # ===================================================
53
- # uncomment the text below to use gemini pipeline instead of the pre-trained model
54
- # ===================================================
55
-
56
class Vocab:
    """Token vocabulary supporting index lookup and pickle round-trips.

    Keeps two parallel structures: ``itos`` (index -> token list) and
    ``stoi`` (token -> index map). Looking up an unknown token resolves to
    the index of ``unk_token``. The frequency counter is transient: it is
    not serialized and comes back empty after unpickling.
    """

    def __init__(self, min_freq=1, unk_token="<UNK>", pad_token="<PAD>"):
        self.min_freq = min_freq
        self.unk_token = unk_token
        self.pad_token = pad_token
        self.freq = Counter()   # token frequencies (build-time only)
        self.itos = []          # index -> token
        self.stoi = {}          # token -> index

    def __len__(self):
        # Vocabulary size == number of known tokens.
        return len(self.itos)

    def __getitem__(self, token: str) -> int:
        """Return the token's index, or the <UNK> index when unknown."""
        try:
            return self.stoi[token]
        except KeyError:
            return self.stoi[self.unk_token]

    # -- pickle support: persist everything except the frequency counter --

    def __getstate__(self):
        return {key: getattr(self, key)
                for key in ('min_freq', 'unk_token', 'pad_token', 'itos', 'stoi')}

    def __setstate__(self, state):
        self.min_freq = state['min_freq']
        self.unk_token = state['unk_token']
        self.pad_token = state['pad_token']
        self.itos = state['itos']
        self.stoi = state['stoi']
        # Frequencies are not persisted; start fresh after unpickling.
        self.freq = Counter()
92
-
93
-
94
def load_vocabs(path: str) -> Tuple[Vocab, Vocab]:
    """Load the (word, char) vocabulary pair from a pickle file.

    Args:
        path: Filesystem path to a pickled ``(word_vocab, char_vocab)`` tuple.

    Returns:
        The ``(word_vocab, char_vocab)`` pair.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        RuntimeError: For any other load/validation failure (wraps the cause).
    """
    try:
        # SECURITY: pickle.load executes arbitrary code from the file —
        # only load vocab files from trusted sources.
        with open(path, "rb") as f:
            word_vocab, char_vocab = pickle.load(f)

        # A usable vocabulary has more than just the special tokens, so
        # <= 2 entries indicates a corrupt or empty file.
        if len(word_vocab) <= 2 or len(char_vocab) <= 2:
            raise IndexError("Vocabulary file loaded but sizes are suspiciously small.")

        return word_vocab, char_vocab
    except FileNotFoundError as e:
        # Re-raise with a clearer message, preserving the original cause.
        raise FileNotFoundError(f"Vocab file not found at {path}. Please check the path.") from e
    except Exception as e:
        raise RuntimeError(f"Error loading vocabs from {path}: {e}") from e
108
-
109
-
110
# Scratch directory for transient uploads; created eagerly so request
# handlers never race on first write. /tmp is assumed writable in the
# deployment container.
UPLOAD_FOLDER = "/tmp"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
112
-
113
-
114
- # ==============================
115
- #
116
- # API
117
- #
118
- # ===============================
119
-
120
@app.route("/create_question_bank", methods=["POST"])
def upload_pdf():
    """Create a question bank from an uploaded PDF.

    Form fields: userId, title, description, pdf (file). The PDF is held
    entirely in memory, passed to the extraction model, and the resulting
    MCQs (with documentIndex + a fresh questionId each) are stored in the
    vector DB. Responds with the new bank's id and an answerFound flag.
    """
    print("\n[START] /create_question_bank request received")

    # 1. Validate inputs
    user_id = request.form.get("userId")
    title = request.form.get("title")
    description = request.form.get("description")
    pdf_file = request.files.get("pdf")

    print(f"[INFO] Received form-data: userId={user_id}, title={title}, description={description}")
    if not pdf_file:
        return jsonify({"error": "PDF file not provided"}), 400

    if not all([user_id, title, description]):
        return jsonify({"error": "userId, title, description are required"}), 400

    # 2. Keep PDF in memory (no Drive)
    print("[STEP] Reading PDF into memory...")
    pdf_bytes = pdf_file.read()
    pdf_name = secure_filename(pdf_file.filename)

    # 3. Directly call model
    print("[STEP] Calling LayoutLM model directly (no Drive)...")
    # final_data = call_layoutlm_api(pdf_bytes, pdf_name)
    final_data = call_yolo_api(pdf_bytes, pdf_name)
    print(f"[SUCCESS] LayoutLM returned {len(final_data)} MCQs")

    # 4. Add index to MCQs — documentIndex preserves source order;
    #    questionId is the stable key used by answer checking later.
    indexed_mcqs = [
        {
            **mcq,
            "documentIndex": i,
            "questionId": str(uuid.uuid4())  # ✅ assign unique ID
        }
        for i, mcq in enumerate(final_data)
    ]

    # 5. Store in vector DB
    print("[STEP] Storing Question Bank in vector database...")
    createdAtTimestamp = datetime.now().isoformat()
    # store_mcqs returns (generatedQAId, all_have_answers).
    stored_id, all_have_answers = store_mcqs(
        user_id, title, description, indexed_mcqs, pdf_name, createdAtTimestamp
    )
    print(f"[SUCCESS] Stored with generatedQAId={stored_id}")

    print("[END] Request complete\n")
    # Response built via json.dumps(ensure_ascii=False) so non-ASCII MCQ
    # text survives unescaped.
    return Response(
        json.dumps({
            "generatedQAId": stored_id,
            "userId": user_id,
            "fileName": pdf_name,
            "createdAt": createdAtTimestamp,
            "answerFound": all_have_answers
        }, ensure_ascii=False),
        mimetype="application/json"
    )
177
-
178
-
179
@app.route("/create_question_bank_image", methods=["POST"])
def upload_image():
    """Create a question bank from one or more uploaded images.

    Form fields: userId, title, description, image (one or more files).
    Mirrors /create_question_bank but loops over each image, merging the
    extracted MCQs into a single bank. Images the model fails on are
    skipped (best effort) rather than aborting the whole request.
    """
    print("\n[START] /create_question_bank request received")

    # 1. Validate inputs
    user_id = request.form.get("userId")
    title = request.form.get("title")
    description = request.form.get("description")
    image_files = request.files.getlist("image")  # ✅ multiple images

    print(f"[INFO] Received form-data: userId={user_id}, title={title}, description={description}")
    if not image_files or len(image_files) == 0:
        return jsonify({"error": "No image file(s) provided"}), 400

    if not all([user_id, title, description]):
        return jsonify({"error": "userId, title, description are required"}), 400

    all_results = []

    # 2. Loop through each image
    for idx, img_file in enumerate(image_files, start=1):
        print(f"[STEP] Reading image {idx}/{len(image_files)} into memory...")
        file_bytes = img_file.read()
        filename = secure_filename(img_file.filename)

        # 3. Directly call model for each image
        print(f"[STEP] Calling LayoutLM model for (unknown) ...")
        try:
            result = call_yolo_api(file_bytes, filename)
            print(f"[SUCCESS] Model returned result for (unknown)")
            if isinstance(result, list):
                all_results.extend(result)
            else:
                all_results.append(result)
        except Exception as e:
            # Best effort: keep processing remaining images.
            print(f"[ERROR] Failed on (unknown): {e}")

    # 4. Add index and a unique questionId to each MCQ.
    #    FIX: questionId was missing here (unlike the PDF route); downstream
    #    answer checking in /submit_test keys on questionId.
    indexed_mcqs = [
        {**mcq, "documentIndex": i, "questionId": str(uuid.uuid4())}
        for i, mcq in enumerate(all_results)
    ]

    # 5. Store in vector DB
    print("[STEP] Storing Question Bank in vector database...")
    createdAtTimestamp = datetime.now().isoformat()
    # FIX: store_mcqs returns (generatedQAId, all_have_answers); previously
    # the whole tuple was bound to stored_id, so the response's
    # generatedQAId was a tuple instead of the id string.
    stored_id, all_have_answers = store_mcqs(
        user_id, title, description, indexed_mcqs, "multiple_images.zip", createdAtTimestamp
    )
    print(f"[SUCCESS] Stored with generatedQAId={stored_id}")

    print("[END] Request complete\n")
    return Response(
        json.dumps({
            "generatedQAId": stored_id,
            "userId": user_id,
            "fileCount": len(image_files),
            "createdAt": createdAtTimestamp,
            "answerFound": all_have_answers
        }, ensure_ascii=False),
        mimetype="application/json"
    )
240
@app.route("/question_bank_by_user", methods=["POST"])
def paper_sets_by_userID():
    """Return every question bank owned by a user, each bank's MCQs sorted
    by documentIndex (missing/None indices sort last for legacy records)."""
    data = request.get_json(silent=True) or request.form.to_dict()
    userId = data.get("userId")

    mcqs_data = fetch_mcqs(userId=userId)
    if not mcqs_data:
        return jsonify({"message": "No Paper Sets found"})

    def _doc_index(mcq):
        # None/missing documentIndex -> +inf so it sorts to the end.
        idx = mcq.get('documentIndex')
        return int(idx) if idx is not None else float('inf')

    # Re-order each bank's question list into document order in place.
    for paper_set in mcqs_data:
        mcq_list = paper_set.get('metadata', {}).get('mcqs')
        if mcq_list:
            paper_set['metadata']['mcqs'] = sorted(mcq_list, key=_doc_index)

    # json.dumps(ensure_ascii=False) keeps non-ASCII question text readable.
    return Response(
        json.dumps(mcqs_data, ensure_ascii=False, indent=4),
        mimetype="application/json"
    )
267
@app.route("/question_bank_by_id", methods=["POST"])
def paper_sets_by_generatedQAId():
    """Fetch one question bank by generatedQAId; MCQs sorted by documentIndex."""
    data = request.get_json(silent=True) or request.form.to_dict()
    generatedQAId = data.get("generatedQAId")

    if not generatedQAId:
        return jsonify({"error": "generatedQAId is required"}), 400

    results = fetch_mcqs(generatedQAId=generatedQAId)

    if not results:
        # NOTE: deliberately 200 (not 404) — existing clients key off the body.
        return jsonify({"error": "No MCQs found for the provided ID"}), 200

    first = results[0]
    mcq_list = first.get('metadata', {}).get('mcqs')
    if mcq_list:
        # Sort by documentIndex, defaulting missing indices to 0.
        first['metadata']['mcqs'] = sorted(
            mcq_list, key=lambda q: q.get('documentIndex', 0)
        )

    # Return the full result list exactly as fetch_mcqs produced it.
    return jsonify(results)
291
-
292
-
293
@app.route("/generate_test", methods=["POST"])
def generate_test():
    """
    API to fetch MCQs by generated-qa-Id and marks (limit),
    and also to create a new test entry.

    Body: generatedQAId (required), marks (int, question count), userId,
    testTitle, totalTime. When userId is supplied the generated test is
    persisted as a session; otherwise it is only returned.
    """
    data = request.get_json(silent=True) or request.form

    generatedQAId = data.get("generatedQAId")
    marks = data.get("marks")
    userId = data.get("userId")
    testTitle = data.get("testTitle")
    totalTime = data.get("totalTime")

    if not generatedQAId:
        return jsonify({"error": "generatedQAId is required"}), 400
    # ... (other validation checks)

    try:
        marks = int(marks)
    except (TypeError, ValueError):
        # FIX: int(None) raises TypeError, which previously escaped the
        # ValueError-only handler and surfaced as a 500.
        return jsonify({"error": "marks must be an integer"}), 400

    testId = str(uuid.uuid4())
    createdAt = datetime.now().isoformat()

    # 1. Fetch random sample of `marks` questions from the bank.
    test_data_results = fetch_random_mcqs(generatedQAId, num_questions=marks)

    if not test_data_results:
        return jsonify({"message": "No MCQs found"}), 200

    mcqs_data = test_data_results[0].get("metadata", {}).get("mcqs", [])

    # The list mcqs_data is now in the final, random order for the test.

    # 2. Assign a new sequential index (testIndex), starting from 1,
    #    for the client/storage.
    final_mcqs_for_storage = []
    for i, mcq in enumerate(mcqs_data):
        mcq['testIndex'] = i + 1
        final_mcqs_for_storage.append(mcq)

    # 3. Store the session using the indexed list
    if userId:
        is_stored = store_test_session(userId, testId, testTitle, totalTime, createdAt, final_mcqs_for_storage)
        if not is_stored:
            return jsonify({"error": "Failed to store test session"}), 500

    # 4. Return the result
    return jsonify({
        "message": "Test created and stored successfully",
        "userId": userId,
        "testId": testId,
        "totalTime": totalTime,
        "createdAt": createdAt,
        "questions": final_mcqs_for_storage  # Return the indexed list
    }), 200
351
-
352
-
353
@app.route("/combined_paperset", methods=["POST"])
def combined_test():
    """Build a single test drawing questions from several question banks.

    Body: userId, testTitle, totalTime, total_questions, and ``sources`` —
    a list of {generatedQAId, percentage} entries whose percentages must
    sum to exactly 100. Questions are sampled per source, shuffled, given
    sequential testIndex values, stored as a session, and returned.
    """
    data = request.get_json(silent=True) or request.form

    userId = data.get("userId")
    testTitle = data.get("testTitle")
    totalTime = data.get("totalTime")
    total_questions = data.get("total_questions")
    sources = data.get("sources")

    # Validate required inputs
    if not all([userId, testTitle, totalTime, total_questions, sources]) or not isinstance(sources, list):
        return jsonify(
            {"error": "userId, testTitle, total_questions, totalTime, and a list of sources are required"}), 400

    try:
        total_questions = int(total_questions)
        # NOTE(review): exact equality check — fractional splits such as
        # 33.3/33.3/33.4 pass, but 33.33*3 would not; confirm intended.
        if sum(s.get("percentage", 0) for s in sources) != 100:
            return jsonify({"error": "Percentages must sum to 100"}), 400
    except (ValueError, TypeError):
        return jsonify({"error": "total_questions must be an integer and percentages must be numbers"}), 400

    all_mcqs = []

    for source in sources:
        qa_id = source.get("generatedQAId")
        percentage = source.get("percentage")

        if not qa_id or not percentage:
            return jsonify({"error": "Each source must have 'generatedQAId' and 'percentage'"}), 400

        # Calculate the number of questions for this source
        # (rounding means the combined total may drift off total_questions).
        num_questions = round(total_questions * (percentage / 100))

        # Fetch a random sample from this source
        # Note: fetch_random_mcqs returns a list containing a dict with metadata/mcqs
        mcqs_record = fetch_random_mcqs(generatedQAId=qa_id, num_questions=num_questions)

        if mcqs_record:
            # Extract the list of questions and combine them
            all_mcqs.extend(mcqs_record[0].get("metadata", {}).get("mcqs", []))

    # Shuffle the combined list of all MCQs to finalize the test order
    random.shuffle(all_mcqs)

    if not all_mcqs:
        return jsonify({"message": "No MCQs found for the provided IDs"}), 200

    # Assign a new, sequential index (testIndex) to each question
    final_mcqs_for_storage = []
    for i, mcq in enumerate(all_mcqs):
        # Assign a sequential index starting from 1
        mcq['testIndex'] = i + 1
        final_mcqs_for_storage.append(mcq)

    # Generate test metadata
    testId = str(uuid.uuid4())
    createdAt = datetime.now().isoformat()

    # Store the test session with the indexed list
    # NOTE(review): unlike /generate_test, the return value of
    # store_test_session is not checked here — storage failures go unnoticed.
    store_test_session(userId, testId, testTitle, totalTime, createdAt, final_mcqs_for_storage)

    return jsonify({
        "userId": userId,
        "testId": testId,
        "testTitle": testTitle,
        "totalTime": totalTime,
        "createdAt": createdAt,
        "questions": final_mcqs_for_storage  # Return the correctly indexed list
    }), 200
423
-
424
-
425
@app.route("/paper_set/<testId>", methods=["GET"])
def testId(testId):
    """Fetch one test session by ID, stripping correct answers first.

    NOTE: existing contract returns 200 with an error body when absent.
    """
    test_data = fetch_test_by_testId(testId)
    if not test_data:
        return jsonify({"error": "Test not found"}), 200

    # Never leak the answer key to the test taker.
    for question in test_data.get("questions", []):
        question.pop("answer", None)

    return jsonify(test_data), 200
438
-
439
-
440
@app.route("/paper_sets_by_user/<userId>", methods=["GET"])
def test_history_by_userId(userId):
    """List a user's test sessions with correct answers stripped out."""
    test_history = test_sessions_by_userId(userId)
    if not test_history:
        return jsonify({"message": "No test sessions found"}), 200

    # Strip the answer key from every question before it reaches the client.
    for session in test_history:
        for question in session.get("questions", []):
            question.pop("answer", None)  # removes if present

    return jsonify(test_history), 200
452
-
453
-
454
@app.route("/submit_test", methods=["POST"])
def submit_test():
    """
    API to submit student answers, check correctness,
    calculate score, and store submission data.
    Frontend sends: userId, testId, testTitle, timeSpent, totalTime, answers[]

    Scoring: each answer is matched to the stored test's questions by
    questionId (preferred) or by case-insensitive question text; the score
    is percentage-correct over the stored question count.
    """
    data = request.get_json(silent=True) or {}

    userId = data.get("userId")
    testId = data.get("testId")
    testTitle = data.get("testTitle")
    timeSpent = data.get("timeSpent")
    totalTime = data.get("totalTime")
    answers = data.get("answers")

    if not all([userId, testId, answers]):
        return jsonify({"error": "Missing required fields: userId, testId, answers"}), 400
    if not isinstance(answers, list):
        return jsonify({"error": "Answers must be a list"}), 400

    submittedAt = datetime.now().isoformat()

    # 🧠 Fetch original test data (includes correct answers)
    test_data = fetch_test_by_testId(testId)
    if not test_data:
        return jsonify({"error": "Test not found"}), 404

    # questions may come back JSON-encoded as a string; normalize to a list.
    questions = test_data.get("questions", [])
    if isinstance(questions, str):
        try:
            questions = json.loads(questions)
        except Exception:
            questions = []

    # Build quick lookup of correct answers
    correct_map = {q.get("questionId"): q.get("answer") for q in questions}

    totalQuestions = len(correct_map)
    total_correct = 0
    results = []

    # ✅ Compare each submitted answer
    for ans in answers:
        qid = ans.get("questionId")
        qtext = ans.get("question")
        user_ans = ans.get("your_answer")

        # Try to get correct answer using questionId first, then question text
        correct_ans = None
        if qid and qid in correct_map:
            correct_ans = correct_map.get(qid)
        elif qtext:
            for q in questions:
                if qtext.strip().lower() == q.get("question", "").strip().lower():
                    correct_ans = q.get("answer")
                    qid = q.get("questionId")
                    break

        # NOTE(review): exact equality — option-letter vs option-text or
        # whitespace differences count as wrong; verify frontend sends the
        # same representation stored in the bank.
        is_correct = (user_ans == correct_ans)

        if is_correct:
            total_correct += 1

        results.append(OrderedDict([
            ("questionId", qid),
            ("your_answer", user_ans),
            ("correct_answer", correct_ans),
            ("is_correct", is_correct)
        ]))

    # 🧮 Calculate score (guard against empty tests to avoid ZeroDivision)
    score = round((total_correct / totalQuestions) * 100, 2) if totalQuestions > 0 else 0.0

    # 💾 Store submission attempt in Qdrant or DB
    is_stored, attemptId = store_submitted_test(
        userId=userId,
        testId=testId,
        testTitle=testTitle,
        timeSpent=timeSpent,
        totalTime=totalTime,
        submittedAt=submittedAt,
        detailed_results=results,
        score=score,
        total_questions=totalQuestions,
        total_correct=total_correct
    )

    if not is_stored:
        return jsonify({"error": "Failed to store submission"}), 500

    # 📦 Final response
    response = OrderedDict([
        ("attemptId", attemptId),
        ("userId", userId),
        ("testId", testId),
        ("testTitle", testTitle),
        ("submittedAt", submittedAt),
        ("timeSpent", timeSpent),
        ("total_questions", totalQuestions),
        ("total_correct", total_correct),
        ("score", score),
        ("detailed_results", results)
    ])

    return jsonify(response)
560
-
561
@app.route("/submitted_tests/<userId>", methods=["GET"])
def submitted_tests_history(userId):
    """
    API to fetch a list of all submitted test sessions for a given user.
    """
    if not userId:
        return jsonify({"error": "userId is required"}), 400

    submitted_tests = submitted_tests_by_userId(userId)

    # None signals a backend failure; an empty list just means no history.
    if submitted_tests is None:
        return jsonify({"error": "An error occurred while fetching submitted tests"}), 500
    if not submitted_tests:
        return jsonify({"message": "No submitted tests found for this user"}), 200

    return jsonify(submitted_tests), 200
578
-
579
-
580
@app.route("/submitted_test/<testId>", methods=["GET"])
def get_single_submitted_test(testId):
    """Fetch details of one submitted test by testId (404 when absent)."""
    if not testId:
        return jsonify({"error": "testId is required"}), 400

    result = fetch_submitted_test_by_testId(testId)
    if not result:
        return jsonify({"message": "No submitted test found"}), 404

    return jsonify(result), 200
594
-
595
-
596
@app.route("/question_bank/<generatedQAId>", methods=["PUT"])
def edit_question_bank(generatedQAId):
    """
    Unified API to perform add, edit, or delete operations on questions,
    and update the question bank's Title and Description.

    Accepts both:
    1. {
         "title": "English Test",
         "description": "Updated chapter 1 test",
         "edits": [ { "operation": "edit", "data": {...}} ]
       }
    2. [ { "operation": "edit", "data": {...}} ]  <- Legacy (frontend-only edits)

    After applying edits, recomputes whether every question has an answer
    and pushes that flag back to the stored bank.
    """

    # Step 1: Parse request JSON
    payload = request.get_json(silent=True) or {}

    # Handle both dict and list payloads (legacy clients POST a bare list).
    if isinstance(payload, list):
        edits = payload
        new_title = None
        new_description = None
    else:
        edits = payload.get("edits")
        new_title = payload.get("title")
        new_description = payload.get("description")

    # Default status used when no metadata change was requested.
    metadata_update_status = {
        "title_updated": False,
        "description_updated": False,
        "success": True
    }

    # --- Step 2: Update Metadata (Title / Description) ---
    try:
        if new_title is not None or new_description is not None:
            metadata_update_status = update_question_bank_metadata(
                generatedQAId=generatedQAId,
                title=new_title,
                description=new_description
            )

        # Handle metadata update failure
        if not metadata_update_status.get("success", True):
            return jsonify({
                "error": f"Failed to update metadata for Question Bank ID: {generatedQAId}"
            }), 500
    except Exception as e:
        # Metadata failure is recorded but does not abort question edits.
        print(f"[ERROR] Metadata update failed: {str(e)}")
        metadata_update_status["success"] = False

    # --- Step 3: Process Question-Level Edits ---
    # Each edit is independent: a failure on one is logged and skipped.
    if edits and isinstance(edits, list):
        for edit in edits:
            try:
                operation = edit.get("operation")
                data = edit.get("data")

                if not operation or not data:
                    continue

                if operation == "add":
                    add_single_question(generatedQAId, data)

                elif operation == "edit":
                    questionId = data.get("questionId")
                    if questionId:
                        update_single_question(questionId, data)

                elif operation == "delete":
                    questionId = data.get("questionId")
                    if questionId:
                        delete_single_question(questionId)

                else:
                    print(f"[WARN] Unknown operation '{operation}' ignored.")

            except Exception as e:
                print(f"[ERROR] Failed to process edit operation: {str(e)}")
                continue

    # --- Step 4: Fetch Updated Data for Response ---
    try:
        updated_data = fetch_mcqs(generatedQAId=generatedQAId)
    except Exception as e:
        print(f"[ERROR] Failed to fetch updated question bank: {str(e)}")
        updated_data = None

    if not updated_data:
        return jsonify({
            "error": "Update processed, but the question bank was not found.",
            "generatedQAId_used": generatedQAId
        }), 404

    # ✅ --- Step 5: Compute answerFound flag ---
    # True only when every question has a non-blank answer string.
    mcqs = updated_data[0]["metadata"].get("mcqs", [])
    all_have_answers = True
    for q in mcqs:
        ans = q.get("answer")
        if not (ans and str(ans).strip()):
            all_have_answers = False
            break

    # ✅ --- Step 6: Update Qdrant MCQ bank with answerFound flag ---
    update_answer_flag_in_qdrant(generatedQAId, all_have_answers)

    updated_questions_count = len(mcqs)

    # ✅ --- Step 7: Return Success Response ---
    return jsonify({
        "message": "Question bank updated successfully",
        "title_updated": metadata_update_status.get("title_updated", False),
        "description_updated": metadata_update_status.get("description_updated", False),
        "updated_questions_count": updated_questions_count,
        "answerFound": all_have_answers
    }), 200
713
-
714
-
715
@app.route("/create_manual_question_bank", methods=["POST"])
def create_manual_question_bank():
    """
    API to create a new question bank and populate it with a list of questions
    in a single request for a smoother user experience.

    Body: userId, title, description, questions (non-empty list of MCQ dicts).
    Each question receives a documentIndex and a fresh questionId before
    storage.
    """
    data = request.get_json(silent=True) or request.form.to_dict()
    user_id = data.get("userId")
    title = data.get("title")
    description = data.get("description")
    raw_mcqs = data.get("questions", [])  # Expects a list of question objects

    if not all([user_id, title, description]) or not isinstance(raw_mcqs, list):
        return jsonify({"error": "userId, title, description, and a list of 'questions' are required"}), 400

    if not raw_mcqs:
        return jsonify({"error": "Question bank must contain at least one question."}), 400

    indexed_mcqs = []

    # 1. Format and Index MCQs (similar to your upload_pdf route logic)
    for i, mcq in enumerate(raw_mcqs):
        # Ensure options are properly formatted (if they come as a dict from the client)
        if 'options' in mcq and isinstance(mcq['options'], dict):
            # We need to ensure the options are stored as a JSON string
            # as required by the ChromaDB metadata constraint (as discovered earlier).
            mcq['options'] = json.dumps(mcq['options'])

        # NOTE: If your database requires questionId/documentIndex, they must be set here.
        # However, we will assume 'store_mcqs_for_manual_creation' handles questionId and documentIndex assignment.
        mcq['documentIndex'] = i
        mcq['questionId'] = str(uuid.uuid4())
        indexed_mcqs.append(mcq)

    # 2. Store Metadata and Questions (using a modified store function)
    try:
        # Create a function similar to store_mcqs but for manual data
        generated_qa_id = store_mcqs_for_manual_creation(
            user_id,
            title,
            description,
            indexed_mcqs
        )
    except Exception as e:
        print(f"Error storing manual question bank: {e}")
        return jsonify({"error": "Failed to create and store question bank"}), 500

    return jsonify({
        "message": "Question bank created and populated successfully",
        "generatedQAId": generated_qa_id,
        "userId": user_id,
        "title": title,
        "questions_count": len(indexed_mcqs)
    }), 201
769
-
770
-
771
@app.route("/question_bank/<generatedQAId>", methods=["DELETE"])
def delete_question_bank(generatedQAId):
    """
    API to delete an entire question bank (metadata and all associated questions).
    """
    if not generatedQAId:
        return jsonify({"error": "generatedQAId is required"}), 400

    # Assume this function handles the deletion from both the main
    # and the questions collection using the generatedQAId.
    success = delete_mcq_bank(generatedQAId)

    if success:
        return jsonify({
            "message": f"Question bank '{generatedQAId}' and all associated questions deleted successfully."
        }), 200
    else:
        # Return 404 if the bank wasn't found to delete, or 500 on database error
        # NOTE(review): despite the comment above, this branch returns 200 —
        # status-code-driven clients cannot distinguish failure here; confirm
        # whether 404 was intended before changing it.
        return jsonify({
            "error": f"Failed to delete question bank '{generatedQAId}'. It may not exist."
        }), 200
792
-
793
-
794
@app.route("/submitted_test/<testId>", methods=["DELETE"])
def delete_submitted_test(testId):
    """Delete a specific submitted test result by its ID (404 when absent)."""
    if not testId:
        return jsonify({"error": "testId is required"}), 400

    if delete_submitted_test_by_id(testId):
        return jsonify({
            "message": f"Submitted test result '{testId}' deleted successfully."
        }), 200

    return jsonify({
        "error": f"Failed to delete submitted test result '{testId}'. It may not exist."
    }), 404
812
-
813
-
814
@app.route("/paper_sets/<testId>", methods=["DELETE"])
def delete_test_session(testId):
    """Delete a specific test session by its ID.

    NOTE: the failure branch intentionally still returns HTTP 200 —
    clients inspect the message body, not the status code.
    """
    if not testId:
        return jsonify({"error": "testId is required"}), 400

    # Deletion from test_sessions_collection is delegated to the DB layer.
    if delete_test_session_by_id(testId):
        return jsonify({
            "message": f"Test '{testId}' deleted successfully."
        }), 200

    return jsonify({
        "message": f"Failed to delete '{testId}' "
    }), 200
833
-
834
-
835
@app.route("/test_attempt/<attemptId>", methods=["DELETE"])
def delete_submitted_test_attempt_api(attemptId):
    """Delete a specific submitted test attempt by attemptId.

    NOTE: failures return HTTP 200 with an error body (existing contract).
    """
    if not attemptId:
        return jsonify({"error": "attemptId is required"}), 400

    if not delete_submitted_test_attempt(attemptId):
        return jsonify({"error": "Failed to delete attempt"}), 200

    return jsonify({
        "message": f"Attempt {attemptId} deleted successfully"
    }), 200
850
-
851
-
852
@app.route("/paper_sets/<testId>", methods=["PUT"])
def edit_paperset(testId):
    """
    Update specific fields of a test session.
    Allows partial updates for test metadata and individual questions.

    Body: optional testTitle / totalTime, plus ``edits`` — a list of
    {operation: add|edit|delete, data: {...}} items keyed by questionId.
    Questions are re-sorted by documentIndex before saving.
    """
    payload = request.get_json(silent=True) or {}

    if not testId:
        return jsonify({"error": "testId is required"}), 400

    # 1️⃣ Fetch existing test session
    existing_record = fetch_test_by_testId(testId)
    if not existing_record:
        return jsonify({"error": f"Test session '{testId}' not found"}), 404

    # Shallow copy: nested question dicts are still shared with the record.
    updated_data = existing_record.copy()

    # Extract fields
    edits = payload.get("edits", [])
    new_title = payload.get("testTitle")
    new_total_time = payload.get("totalTime")

    # --- Step 2: Update Top-Level Fields ---
    if new_title is not None:
        updated_data["testTitle"] = new_title

    if new_total_time is not None:
        updated_data["totalTime"] = new_total_time

    # --- Step 3: Question Operations ---
    # NOTE(review): this raises KeyError (-> 500) if any stored question
    # lacks a 'questionId'; older records created without IDs would fail.
    existing_questions = {q["questionId"]: q for q in updated_data.get("questions", [])}

    for edit in edits:
        operation = edit.get("operation")
        data = edit.get("data")

        if not operation or not data:
            continue

        # ---------- ADD ----------
        if operation == "add":
            qid = data.get("questionId")
            if not qid:
                continue

            # Set default fields for new question
            data.setdefault("documentIndex", len(existing_questions))
            data.setdefault("testIndex", len(existing_questions) + 1)
            data.setdefault("userId", updated_data.get("userId"))
            data.setdefault("generatedQAId", updated_data.get("generatedQAId"))
            data.setdefault("passage", "")
            data.setdefault("image", None)
            data.setdefault("noise", "")

            existing_questions[qid] = data

        # ---------- EDIT ----------
        elif operation == "edit":
            qid = data.get("questionId")
            if qid and qid in existing_questions:
                for key, value in data.items():
                    existing_questions[qid][key] = value

        # ---------- DELETE ----------
        elif operation == "delete":
            qid = data.get("questionId")
            if qid in existing_questions:
                del existing_questions[qid]

    # Sort after update (missing documentIndex sorts last via sentinel).
    updated_data["questions"] = sorted(
        list(existing_questions.values()),
        key=lambda q: q.get("documentIndex", 999999)
    )

    # --- Step 4: Save back ---
    success = update_test_session(testId, updated_data)

    if success:
        return jsonify({
            "message": "Test session updated successfully",
            "testId": testId,
            "updated_fields": list(payload.keys())
        }), 200
    else:
        return jsonify({"error": "Failed to update test session"}), 500
939
-
940
-
941
if __name__ == "__main__":
    # Development entry point only — Flask's debug server must not be
    # exposed in production (debug=True enables the Werkzeug debugger).
    app.run(host="0.0.0.0", port=10000, debug=True)
 
1
import os
import json
import re
import torch
import gradio as gr
import google.generativeai as genai
from sentence_transformers import SentenceTransformer, util

# ============================================================
# CONFIG
# ============================================================
# SECURITY FIX: the key was hard-coded in source (and therefore leaked
# with the repository). Read it from the environment instead — this also
# makes the guard below meaningful rather than dead code.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise RuntimeError("Set GEMINI_API_KEY environment variable")

genai.configure(api_key=GEMINI_API_KEY)

# Single shared Gemini model handle used by all prompts in this module.
MODEL = genai.GenerativeModel("gemini-pro")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Cosine-similarity cutoff above which a rubric criterion counts as satisfied.
SIM_THRESHOLD = 0.55

print("Loading embedding model...")
embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
print("✅ Ready")
27
+
28
+ # ============================================================
29
+ # UTILS
30
+ # ============================================================
31
def split_sentences(text):
    """Split *text* at sentence-ending punctuation and return the pieces.

    Fragments of five characters or fewer (after stripping) are discarded.
    """
    fragments = re.split(r'(?<=[.!?])\s+', text)
    sentences = []
    for fragment in fragments:
        trimmed = fragment.strip()
        if len(trimmed) > 5:
            sentences.append(trimmed)
    return sentences
33
+
34
def gemini(prompt, max_tokens=256):
    """Send *prompt* to Gemini and return the stripped response text.

    Temperature is pinned to 0.0 so grading output is deterministic.
    """
    config = genai.types.GenerationConfig(
        temperature=0.0,
        max_output_tokens=max_tokens,
    )
    reply = MODEL.generate_content(prompt, generation_config=config)
    return reply.text.strip()
43
+
44
def safe_json(text):
    """Best-effort JSON parse of model output.

    Tries a direct parse first; on failure, retries on the substring
    between the first ``{`` and the last ``}`` (models often wrap JSON
    in prose). Returns the parsed object, or None if nothing parses.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:  # narrowed from a bare except
        pass

    start = text.find("{")
    end = text.rfind("}") + 1
    # BUG FIX: the old guard tested `end != -1`, which is always true
    # (rfind() returns -1, so end becomes 0). Require an actual,
    # correctly ordered {...} span before retrying.
    if start != -1 and end > start:
        try:
            return json.loads(text[start:end])
        except json.JSONDecodeError:
            return None
    return None
55
+
56
+ # ============================================================
57
+ # STEP 1: INTENT
58
+ # ============================================================
59
def detect_intent(question):
    """Classify *question* into one of five intent labels via Gemini.

    Returns "EXPLANATORY" whenever the model answers with anything
    outside the allowed label set.
    """
    allowed = {
        "FACTUAL", "EXPLANATORY", "CHARACTER_ARC", "PROCESS", "COMPARISON"
    }
    prompt = f"""
Classify the question intent. Choose ONE:
FACTUAL, EXPLANATORY, CHARACTER_ARC, PROCESS, COMPARISON

Question:
{question}

Output ONLY the label.
"""
    label = gemini(prompt, 20)
    return label if label in allowed else "EXPLANATORY"
73
+
74
+ # ============================================================
75
+ # STEP 2: RUBRIC GENERATION
76
+ # ============================================================
77
def generate_rubric(kb, question, intent):
    """Ask Gemini for a grading rubric grounded in the knowledge base.

    Returns a list of atomic criterion strings, or an empty list when the
    model's reply cannot be parsed into the expected JSON shape.
    """
    prompt = f"""
You are an examiner.

Using ONLY the knowledge base, create a grading rubric for the question.
Each item must be an atomic idea a student must mention.

Rules:
- 3 to 6 criteria
- No paraphrasing the question
- No explanations
- Capture progression if relevant
- STRICT JSON ONLY

Format:
{{ "criteria": ["criterion 1", "criterion 2"] }}

Knowledge Base:
{kb}

Question:
{question}

Intent:
{intent}
"""
    reply = gemini(prompt, 300)
    parsed = safe_json(reply)
    if parsed and "criteria" in parsed:
        return parsed["criteria"]
    return []
106
+
107
+ # ============================================================
108
+ # STEP 3: SEMANTIC MATCHING
109
+ # ============================================================
110
def score(answer, criteria):
    """Score each rubric criterion against the student's answer.

    Every answer sentence is embedded once; a criterion's score is its best
    cosine similarity against any sentence, and it is "satisfied" when that
    score reaches SIM_THRESHOLD.
    """
    sentences = split_sentences(answer)
    sentence_vecs = embedder.encode(sentences, convert_to_tensor=True)

    scored = []
    for criterion in criteria:
        criterion_vec = embedder.encode(criterion, convert_to_tensor=True)
        similarities = util.cos_sim(criterion_vec, sentence_vecs)[0]
        # Guard against an empty similarity row (e.g. answer had no usable sentences).
        best = float(torch.max(similarities)) if similarities.numel() else 0.0
        scored.append({
            "criterion": criterion,
            "score": round(best, 3),
            "satisfied": best >= SIM_THRESHOLD,
        })
    return scored
126
+
127
+ # ============================================================
128
+ # FINAL VERDICT
129
+ # ============================================================
130
def verdict(scored):
    """Collapse per-criterion results into an overall verdict string.

    All criteria satisfied -> CORRECT; at least half (minimum one) ->
    PARTIALLY CORRECT; otherwise INCORRECT.
    """
    total = len(scored)
    satisfied = sum(1 for item in scored if item["satisfied"])

    if satisfied == total:
        return " CORRECT"
    if satisfied >= max(1, total // 2):
        return "⚠️ PARTIALLY CORRECT"
    return " INCORRECT"
139
+
140
+ # ============================================================
141
+ # PIPELINE
142
+ # ============================================================
143
def evaluate(answer, question, kb):
    """Run the full grading pipeline: intent -> rubric -> scoring -> verdict.

    Returns a dict with the detected intent, the generated rubric, the
    per-criterion scoring, and the final verdict string. When no rubric
    could be generated, scoring is empty and the verdict flags it.
    """
    intent = detect_intent(question)
    criteria = generate_rubric(kb, question, intent)

    if criteria:
        scored = score(answer, criteria)
        final = verdict(scored)
    else:
        scored = []
        final = "⚠️ NO RUBRIC"

    return {
        "intent": intent,
        "rubric": criteria,
        "scoring": scored,
        "final_verdict": final,
    }
154
+
155
+ # ============================================================
156
+ # UI
157
+ # ============================================================
158
# ============================================================
# UI — minimal Gradio front-end wired to evaluate()
# ============================================================
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Gemini-powered Answer Grader")

    kb = gr.Textbox(label="Knowledge Base", lines=8)
    q = gr.Textbox(label="Question")
    a = gr.Textbox(label="Student Answer", lines=6)

    out = gr.JSON(label="Evaluation")

    run_button = gr.Button("Evaluate")
    # Argument order matters: evaluate(answer, question, kb).
    run_button.click(evaluate, [a, q, kb], out)

demo.launch()