Spaces:
Running
Running
| from flask import Blueprint, jsonify, current_app, render_template, request | |
| from flask_login import login_required, current_user | |
| from utils import get_db_connection | |
| import os | |
| import time | |
| import json | |
| from processing import resize_image_if_needed, call_nim_ocr_api | |
| from gemini_classifier import classify_questions_with_gemini | |
| from gemma_classifier import GemmaClassifier | |
| from nova_classifier import classify_questions_with_nova | |
| import requests | |
| from nvidia_prompts import BIOLOGY_PROMPT_TEMPLATE, CHEMISTRY_PROMPT_TEMPLATE, PHYSICS_PROMPT_TEMPLATE, MATHEMATICS_PROMPT_TEMPLATE | |
| # Blueprint grouping all classification-related routes for this module. | |
| classifier_bp = Blueprint('classifier_bp', __name__) | |
| # Instantiate classifiers | |
| # Module-level singleton so the Gemma classifier is constructed once per | |
| # process instead of once per request. | |
| gemma_classifier = GemmaClassifier() | |
# Dispatch table mapping supported subjects (lowercase) to their NVIDIA
# prompt templates. Kept module-level so it is built once.
_SUBJECT_PROMPT_TEMPLATES = {
    'biology': BIOLOGY_PROMPT_TEMPLATE,
    'chemistry': CHEMISTRY_PROMPT_TEMPLATE,
    'physics': PHYSICS_PROMPT_TEMPLATE,
    'mathematics': MATHEMATICS_PROMPT_TEMPLATE,
}


def get_nvidia_prompt(subject, input_questions):
    """Return the subject-specific NVIDIA prompt with the questions inserted.

    Args:
        subject: Subject name (case-insensitive); one of biology, chemistry,
            physics, mathematics.
        input_questions: Text block substituted into the template's
            ``{input_questions}`` placeholder.

    Returns:
        The formatted prompt string, or None for an unsupported subject.
    """
    template = _SUBJECT_PROMPT_TEMPLATES.get(subject.lower())
    if template is None:
        return None
    return template.format(input_questions=input_questions)
def _text_from_ocr(ocr_result):
    """Join the text detections of a NIM OCR response into a single string.

    Returns None when the response carries no detections.
    """
    if ocr_result.get('data') and ocr_result['data'][0].get('text_detections'):
        return " ".join(
            item['text_prediction']['text']
            for item in ocr_result['data'][0]['text_detections']
        )
    return None


def _strip_code_fences(content):
    """Return the payload of a model response, unwrapping ```json / ``` fences."""
    if "```json" in content:
        return content.split("```json")[1].split("```")[0].strip()
    if "```" in content:
        return content.split("```")[1].split("```")[0].strip()
    return content


def get_topic_suggestions():
    """Suggest chapter classifications for one question via the NVIDIA API.

    Expects a JSON body with:
      - subject (required): subject name understood by get_nvidia_prompt.
      - question_text (optional): the question text to classify.
      - image_id (optional): used to load stored text, or to run OCR on the
        processed image, when question_text is absent.

    Returns JSON with 'suggestions' (primary chapter first, then any
    'other_possible_chapters' the model supplies) plus the full parsed model
    response, or an error payload with an appropriate HTTP status.
    """
    data = request.json
    question_text = data.get('question_text')
    image_id = data.get('image_id')
    subject = data.get('subject')
    if not subject:
        return jsonify({'error': 'Subject is required'}), 400

    # If text is missing but we have image_id, check the DB first, then OCR.
    if not question_text and image_id:
        try:
            conn = get_db_connection()
            try:
                row = conn.execute(
                    'SELECT q.question_text, i.processed_filename, s.user_id '
                    'FROM questions q JOIN images i ON q.image_id = i.id '
                    'JOIN sessions s ON i.session_id = s.id WHERE i.id = ?',
                    (image_id,)
                ).fetchone()
                # Security: ownership check, consistent with the other
                # endpoints in this module.
                if row and row['user_id'] != current_user.id:
                    return jsonify({'error': 'Unauthorized'}), 403
                if row:
                    if row['question_text']:
                        question_text = row['question_text']
                    elif row['processed_filename']:
                        # No stored text yet: run OCR on the processed image.
                        image_path = os.path.join(
                            current_app.config['PROCESSED_FOLDER'],
                            row['processed_filename']
                        )
                        if os.path.exists(image_path):
                            image_bytes = resize_image_if_needed(image_path)
                            question_text = _text_from_ocr(call_nim_ocr_api(image_bytes))
                            if question_text:
                                # Persist so the next request skips OCR.
                                conn.execute(
                                    'UPDATE questions SET question_text = ? WHERE image_id = ?',
                                    (question_text, image_id)
                                )
                                conn.commit()
            finally:
                # BUGFIX: previously the connection leaked whenever a DB or
                # OCR call above raised.
                conn.close()
        except Exception as e:
            current_app.logger.error(f"Error fetching/OCRing text for image {image_id}: {e}")
            return jsonify({'error': f"OCR failed: {str(e)}"}), 500

    if not question_text:
        return jsonify({'error': 'Could not obtain question text (OCR failed or no text found).'}), 400

    # Format as "1. {text}" to match the numbered-list pattern the prompt
    # templates expect (they also tolerate raw text).
    prompt_content = get_nvidia_prompt(subject, f"1. {question_text}")
    if not prompt_content:
        return jsonify({'error': f'Unsupported subject: {subject}'}), 400

    # Call the NVIDIA chat-completions API.
    nvidia_api_key = os.environ.get('NVIDIA_API_KEY')
    if not nvidia_api_key:
        return jsonify({'error': 'NVIDIA_API_KEY not set'}), 500
    invoke_url = 'https://integrate.api.nvidia.com/v1/chat/completions'
    headers = {
        'Authorization': f'Bearer {nvidia_api_key}',
        'Accept': 'application/json',
        'Content-Type': 'application/json'
    }
    payload = {
        "model": "nvidia/nemotron-3-nano-30b-a3b",
        "messages": [
            {
                "content": prompt_content,
                "role": "user"
            }
        ],
        # Low temperature: we want accurate suggestions, not creative variety.
        "temperature": 0.2,
        "top_p": 1,
        "max_tokens": 1024,
        "stream": False
    }
    try:
        response = requests.post(invoke_url, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        content = response.json()['choices'][0]['message']['content']
        # The model may wrap its JSON answer in markdown code fences.
        parsed = json.loads(_strip_code_fences(content))
        suggestions = []
        if parsed.get('data'):
            first = parsed['data'][0]
            suggestions.append(first.get('chapter_title', 'Unclassified'))
            # Optional alternatives, if the prompt/model provides them.
            others = first.get('other_possible_chapters')
            if isinstance(others, list):
                suggestions.extend(others)
        return jsonify({'success': True, 'suggestions': suggestions, 'full_response': parsed})
    except Exception as e:
        current_app.logger.error(f"NVIDIA API Error: {e}")
        return jsonify({'error': str(e)}), 500
def update_question_classification_single():
    """Update subject/chapter for the question tied to a single image.

    Expects JSON with 'image_id' (required), 'subject' and 'chapter'.
    Returns 403 unless the image's session belongs to the current user.
    """
    data = request.json
    image_id = data.get('image_id')
    subject = data.get('subject')
    chapter = data.get('chapter')
    if not image_id:
        return jsonify({'error': 'Image ID is required'}), 400
    try:
        conn = get_db_connection()
        try:
            # Security: check ownership via session -> images.
            image_owner = conn.execute("""
                SELECT s.user_id
                FROM images i
                JOIN sessions s ON i.session_id = s.id
                WHERE i.id = ?
            """, (image_id,)).fetchone()
            if not image_owner or image_owner['user_id'] != current_user.id:
                return jsonify({'error': 'Unauthorized'}), 403
            conn.execute(
                'UPDATE questions SET subject = ?, chapter = ? WHERE image_id = ?',
                (subject, chapter, image_id)
            )
            conn.commit()
        finally:
            # BUGFIX: previously leaked the connection when a query raised.
            conn.close()
        return jsonify({'success': True})
    except Exception as e:
        current_app.logger.error(f"Error updating question classification for image {image_id}: {e}")
        return jsonify({'error': str(e)}), 500
def edit_classified_questions():
    """Renders the page for editing classified questions.

    Builds, for the current user only: the classified-question list (with
    display-truncated text), the distinct chapter list, and the distinct
    tag set used for autocomplete suggestions.
    """
    AVAILABLE_SUBJECTS = ["Biology", "Chemistry", "Physics", "Mathematics"]
    conn = get_db_connection()
    try:
        # Security: fetch only questions belonging to the current user.
        questions_from_db = conn.execute("""
            SELECT q.id, q.question_text, q.chapter, q.subject, q.tags
            FROM questions q
            JOIN sessions s ON q.session_id = s.id
            WHERE s.user_id = ? AND q.subject IS NOT NULL AND q.chapter IS NOT NULL
            ORDER BY q.id
        """, (current_user.id,)).fetchall()
        questions = []
        for q in questions_from_db:
            q_dict = dict(q)
            plain_text = q_dict.get('question_text') or ''  # question_text may be NULL
            # Truncate long texts for the list view.
            q_dict['question_text_plain'] = (plain_text[:100] + '...') if len(plain_text) > 100 else plain_text
            questions.append(q_dict)
        # Chapter/tag suggestions are user-specific as well.
        chapters = conn.execute(
            'SELECT DISTINCT q.chapter FROM questions q JOIN sessions s ON q.session_id = s.id WHERE s.user_id = ? AND q.chapter IS NOT NULL ORDER BY q.chapter',
            (current_user.id,)
        ).fetchall()
        tags_query = conn.execute(
            "SELECT DISTINCT q.tags FROM questions q JOIN sessions s ON q.session_id = s.id WHERE s.user_id = ? AND q.tags IS NOT NULL AND q.tags != ''",
            (current_user.id,)
        ).fetchall()
        all_tags = set()
        for row in tags_query:
            # Tags are stored as a comma-separated string per question.
            all_tags.update(tag.strip() for tag in row['tags'].split(','))
    finally:
        # BUGFIX: previously any query failure leaked the connection.
        conn.close()
    return render_template('classified_edit.html',
                           questions=questions,
                           chapters=[c['chapter'] for c in chapters],
                           all_tags=sorted(all_tags),
                           available_subjects=AVAILABLE_SUBJECTS)
def update_classified_question(question_id):
    """Handles updating a question's metadata (chapter and subject).

    Expects JSON with non-empty 'chapter' and 'subject'. Returns 403 unless
    the question's session belongs to the current user.
    """
    data = request.json
    new_chapter = data.get('chapter')
    new_subject = data.get('subject')
    if not new_chapter or not new_subject:
        return jsonify({'error': 'Chapter and Subject cannot be empty.'}), 400
    try:
        conn = get_db_connection()
        try:
            # Security: check ownership before update.
            question_owner = conn.execute(
                "SELECT s.user_id FROM questions q JOIN sessions s ON q.session_id = s.id WHERE q.id = ?",
                (question_id,)
            ).fetchone()
            if not question_owner or question_owner['user_id'] != current_user.id:
                return jsonify({'error': 'Unauthorized'}), 403
            conn.execute(
                'UPDATE questions SET chapter = ?, subject = ? WHERE id = ?',
                (new_chapter, new_subject, question_id)
            )
            conn.commit()
        finally:
            # BUGFIX: previously leaked the connection when a query raised.
            conn.close()
        return jsonify({'success': True})
    except Exception as e:
        current_app.logger.error(f"Error updating question {question_id}: {repr(e)}")
        return jsonify({'error': str(e)}), 500
def delete_classified_question(question_id):
    """Handles deleting a classified question.

    Note: the question row is kept — only its subject/chapter classification
    is cleared (set to NULL). Returns 403 unless the question's session
    belongs to the current user.
    """
    try:
        conn = get_db_connection()
        try:
            # Security: check ownership before delete.
            question_owner = conn.execute(
                "SELECT s.user_id FROM questions q JOIN sessions s ON q.session_id = s.id WHERE q.id = ?",
                (question_id,)
            ).fetchone()
            if not question_owner or question_owner['user_id'] != current_user.id:
                return jsonify({'error': 'Unauthorized'}), 403
            # Clear the classification rather than deleting the row.
            conn.execute('UPDATE questions SET subject = NULL, chapter = NULL WHERE id = ?', (question_id,))
            conn.commit()
        finally:
            # BUGFIX: previously leaked the connection when a query raised.
            conn.close()
        return jsonify({'success': True})
    except Exception as e:
        current_app.logger.error(f"Error deleting question {question_id}: {repr(e)}")
        return jsonify({'error': str(e)}), 500
def delete_many_classified_questions():
    """Handles bulk deleting classified questions.

    Expects JSON with 'ids': a list of question IDs. Like the single-delete
    endpoint, this only clears subject/chapter (sets them NULL) and silently
    skips IDs not owned by the current user.
    """
    data = request.json
    question_ids = data.get('ids', [])
    if not question_ids:
        return jsonify({'error': 'No question IDs provided.'}), 400
    try:
        conn = get_db_connection()
        try:
            # Security: filter IDs to only those owned by the user.
            placeholders = ','.join('?' for _ in question_ids)
            owned_q_ids_rows = conn.execute(f"""
                SELECT q.id FROM questions q
                JOIN sessions s ON q.session_id = s.id
                WHERE q.id IN ({placeholders}) AND s.user_id = ?
            """, (*question_ids, current_user.id)).fetchall()
            owned_q_ids = [row['id'] for row in owned_q_ids_rows]
            if not owned_q_ids:
                return jsonify({'success': True, 'message': 'No owned questions to delete.'})
            update_placeholders = ','.join('?' for _ in owned_q_ids)
            conn.execute(
                f'UPDATE questions SET subject = NULL, chapter = NULL WHERE id IN ({update_placeholders})',
                owned_q_ids
            )
            conn.commit()
        finally:
            # BUGFIX: previously leaked the connection when a query raised.
            conn.close()
        return jsonify({'success': True})
    except Exception as e:
        current_app.logger.error(f"Error deleting questions: {repr(e)}")
        return jsonify({'error': str(e)}), 500
| from rich.table import Table | |
| from rich.console import Console | |
def extract_and_classify_all(session_id):
    """OCR every cropped image in a session, then classify the texts in batches.

    Pipeline: (1) verify session ownership; (2) run NIM OCR on each cropped
    image and store the text on the matching question row; (3) classify the
    texts in batches of 7 using the user's preferred model (gemini/gemma/nova)
    and write subject/chapter back to the DB after each batch.

    NOTE(review): the DB connection is not closed if an exception is raised
    mid-pipeline (no try/finally) — consider a finally block. Also
    'batch_image_ids' below is computed but never used; DB updates resolve
    image IDs via the classifier's global index instead — verify intent.
    """
    try:
        conn = get_db_connection()
        # Security: Check ownership of the session
        session_owner = conn.execute('SELECT user_id FROM sessions WHERE id = ?', (session_id,)).fetchone()
        if not session_owner or session_owner['user_id'] != current_user.id:
            conn.close()
            return jsonify({'error': 'Unauthorized'}), 403
        # Only 'cropped' images are question images; ordering by id keeps the
        # text/image-id lists aligned with question order.
        images = conn.execute(
            "SELECT id, processed_filename FROM images WHERE session_id = ? AND image_type = 'cropped' ORDER BY id",
            (session_id,)
        ).fetchall()
        if not images:
            conn.close()
            return jsonify({'error': 'No cropped images found in session'}), 404
        current_app.logger.info(f"Found {len(images)} images to process for user {current_user.id}.")
        # Parallel lists: question_texts[i] is the OCR text of image_ids[i].
        question_texts = []
        image_ids = []
        for image in images:
            image_id = image['id']
            processed_filename = image['processed_filename']
            if not processed_filename:
                continue
            image_path = os.path.join(current_app.config['PROCESSED_FOLDER'], processed_filename)
            if not os.path.exists(image_path):
                continue
            image_bytes = resize_image_if_needed(image_path)
            ocr_result = call_nim_ocr_api(image_bytes)
            current_app.logger.info(f"NVIDIA OCR Result for image {image_id}: {ocr_result}")
            # Skip images where OCR produced no detections.
            if not ocr_result.get('data') or not ocr_result['data'][0].get('text_detections'):
                current_app.logger.error(f"NVIDIA OCR result for image {image_id} does not contain 'text_detections' key. Full response: {ocr_result}")
                continue
            # Join all detected fragments into one text blob per image.
            text = " ".join(item['text_prediction']['text'] for item in ocr_result['data'][0]['text_detections'])
            conn.execute('UPDATE questions SET question_text = ? WHERE image_id = ?', (text, image_id))
            current_app.logger.info(f"Updated question_text for image_id: {image_id}")
            question_texts.append(text)
            image_ids.append(image_id)
        conn.commit()
        # --- Batch Processing and Classification ---
        batch_size = 7  # Default batch size
        total_questions = len(question_texts)
        # Ceiling division: number of batches needed to cover all questions.
        num_batches = (total_questions + batch_size - 1) // batch_size
        total_update_count = 0
        for i in range(num_batches):
            start_index = i * batch_size
            end_index = start_index + batch_size
            batch_texts = question_texts[start_index:end_index]
            batch_image_ids = image_ids[start_index:end_index]
            if not batch_texts:
                continue
            current_app.logger.info(f"Processing Batch {i+1}/{num_batches}...")
            # Choose classifier based on user preference; defaults to Gemini
            # when the user record has no 'classifier_model' attribute.
            classifier_model = getattr(current_user, 'classifier_model', 'gemini')
            if classifier_model == 'nova':
                current_app.logger.info(f"Using Nova classifier for user {current_user.id}")
                classification_result = classify_questions_with_nova(batch_texts, start_index=start_index)
                model_name = "Nova"
            elif classifier_model == 'gemma':
                current_app.logger.info(f"Using Gemma classifier for user {current_user.id}")
                classification_result = gemma_classifier.classify(batch_texts, start_index=start_index)
                model_name = "Gemma"
            else:
                current_app.logger.info(f"Using Gemini classifier for user {current_user.id}")
                classification_result = classify_questions_with_gemini(batch_texts, start_index=start_index)
                model_name = "Gemini"
            # Log the result to the terminal
            current_app.logger.info(f"--- Classification Result ({model_name}) for Batch {i+1} ---")
            current_app.logger.info(json.dumps(classification_result, indent=2))
            current_app.logger.info("---------------------------------------------")
            if not classification_result or not classification_result.get('data'):
                current_app.logger.error(f'{model_name} classifier did not return valid data for batch {i+1}.')
                continue  # Move to the next batch
            # --- Immediate DB Update for the Batch ---
            batch_update_count = 0
            for item in classification_result.get('data', []):
                item_index_global = item.get('index')  # This is the global index (e.g., 1 to 14)
                if item_index_global is not None:
                    # Find the corresponding local index in our full list
                    try:
                        # The item_index_global is 1-based, our list is 0-based
                        local_list_index = item_index_global - 1
                        # Find the image_id for that question
                        matched_id = image_ids[local_list_index]
                    except IndexError:
                        current_app.logger.error(f"Classifier returned an out-of-bounds index: {item_index_global}")
                        continue
                    new_subject = item.get('subject')
                    new_chapter = item.get('chapter_title')
                    # Ensure new_subject and new_chapter are strings, not lists
                    # (some models return multi-valued fields as lists).
                    if isinstance(new_subject, list):
                        new_subject = ', '.join(str(x) for x in new_subject) if new_subject else 'Unclassified'
                    elif new_subject is None:
                        new_subject = 'Unclassified'
                    else:
                        new_subject = str(new_subject)
                    if isinstance(new_chapter, list):
                        new_chapter = ', '.join(str(x) for x in new_chapter) if new_chapter else 'Unclassified'
                    elif new_chapter is None:
                        new_chapter = 'Unclassified'
                    else:
                        new_chapter = str(new_chapter)
                    # Only persist when at least a real subject was found; a
                    # subject without a chapter is stored as 'Unclassified'.
                    if new_subject and new_subject != 'Unclassified' and new_chapter and new_chapter != 'Unclassified':
                        conn.execute('UPDATE questions SET subject = ?, chapter = ? WHERE image_id = ?', (new_subject, new_chapter, matched_id))
                        batch_update_count += 1
                    elif new_subject and new_subject != 'Unclassified':
                        conn.execute('UPDATE questions SET subject = ?, chapter = ? WHERE image_id = ?', (new_subject, 'Unclassified', matched_id))
                        batch_update_count += 1
            # Commit after every batch so partial progress survives a failure.
            conn.commit()
            total_update_count += batch_update_count
            current_app.logger.info(f"Batch {i+1} processed. Updated {batch_update_count} questions in the database.")
            if i < num_batches - 1:
                # Throttle between batches, presumably to respect classifier
                # API rate limits — TODO confirm.
                current_app.logger.info("Waiting 5 seconds before next batch...")
                time.sleep(5)
        conn.close()
        return jsonify({'success': True, 'message': f'Successfully extracted and classified {total_questions} questions. Updated {total_update_count} entries in the database.'})
    except Exception as e:
        current_app.logger.error(f'Failed to extract and classify questions: {str(e)}', exc_info=True)
        return jsonify({'error': f'Failed to extract and classify questions: {str(e)}'}), 500