Space status: Sleeping
Refactor app.py to enhance MCQ generation and error handling. Introduced logging for better traceability, improved file upload validation, and updated model names for consistency. Enhanced the parsing logic for MCQs and refined error management during document comparison. Overall, these changes improve code maintainability and robustness.
4ee6059 | from flask import Flask, render_template, request, jsonify | |
| import os | |
| import tempfile | |
| import pandas as pd | |
| from werkzeug.utils import secure_filename | |
| from datetime import datetime | |
| from typing import List, Dict, Any, Optional, Union | |
| from pydantic import BaseModel, Field | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| import tiktoken | |
| import json | |
| from dotenv import load_dotenv | |
| import logging | |
| import sys | |
# --- Logging & environment setup -----------------------------------------
# Log to stdout so platform log collectors (e.g. container runtimes on
# Hugging Face Spaces) capture everything.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)

# Pull environment variables from a local .env file when present.
load_dotenv()
| # Define data models | |
class MCQ(BaseModel):
    """A single multiple-choice question with its options and answer key."""

    question: str
    options: List[str]
    correct_answer: str
    # Label of the document the question was generated from.
    source_name: str = Field(default="Unknown")
class Document(BaseModel):
    """A document under comparison, together with the MCQs derived from it."""

    name: str = ''
    content: str
    mcqs: List[MCQ] = Field(default_factory=list)
# --- Flask application setup ---------------------------------------------
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = tempfile.mkdtemp()
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max file size
ALLOWED_EXTENSIONS = {'txt'}
MODELS = ['gpt-4o', 'gpt-4o-mini']

# Load the per-note-type relevancy criteria used for MCQ generation.
# Fall back to an empty mapping so the app can still start without the file.
try:
    with open('note_criteria.json', 'r') as f:
        NOTE_CRITERIA = json.load(f)['note_types']
except Exception as e:
    logger.error(f"Error loading note_criteria.json: {e}")
    NOTE_CRITERIA = {}
def allowed_file(filename):
    """Check if an uploaded file has an allowed extension."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS
def num_tokens_from_messages(messages, model="gpt-4o"):
    """Estimate token usage for a list of chat messages.

    Args:
        messages: Iterable of dicts whose values are the message fields
            (role, content, ...).
        model: Model name used to select the tokenizer. The original code
            accepted this parameter but ignored it, always using
            cl100k_base; we now ask tiktoken for the model's encoding and
            fall back to cl100k_base for unknown models.

    Returns:
        Approximate token count, or 0 if counting fails.
    """
    try:
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # Unknown model name: fall back to the previous behavior.
            encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = 0
        for message in messages:
            num_tokens += 4  # per-message overhead of the chat framing
            for value in message.values():  # keys are not tokenized
                num_tokens += len(encoding.encode(str(value)))
        num_tokens += 2  # reply-priming overhead
        return num_tokens
    except Exception as e:
        logger.warning(f"Error counting tokens: {e}")
        return 0
def generate_mcqs_for_note(note_content: str, total_tokens: List[int], source_name: str = '', document_type: str = 'discharge_note', model: str = 'gpt-4o-mini') -> List[MCQ]:
    """Generate Multiple Choice Questions from a medical note.

    Args:
        note_content: Raw text of the note to build questions from.
        total_tokens: Single-element list used as a mutable token counter;
            incremented in place with this call's estimated usage.
        source_name: Label recorded on each generated MCQ.
        document_type: Key into NOTE_CRITERIA selecting the relevancy
            criteria for this note type.
        model: Chat model to use. Defaults to 'gpt-4o-mini' (the previously
            hard-coded value), so existing callers are unaffected; new
            callers can forward the user-selected model.

    Returns:
        Parsed MCQs, or an empty list on any failure.
    """
    logger.info(f"Starting MCQ generation for {source_name}")
    try:
        # Build the numbered criteria list for the prompt. An unknown
        # document_type raises KeyError, handled by the catch-all below.
        criteria = NOTE_CRITERIA[document_type]['relevancy_criteria']
        criteria_list = "\n".join(f"{i+1}. {criterion}" for i, criterion in enumerate(criteria))
        system_prompt = f"""
You are an expert in creating MCQs based on medical notes. Generate 20 MCQs that ONLY focus on these key areas:
{criteria_list}
Rules and Format:
1. Each question must relate to specific content from these areas
2. Skip areas not mentioned in the note
3. Each question must have exactly 5 options (A-D plus E="I don't know")
4. Provide only questions and answers, no explanations
5. Use this exact format:
Question: [text]
A. [option]
B. [option]
C. [option]
D. [option]
E. I don't know
Correct Answer: [letter]
IMPORTANT: You MUST generate exactly 20 questions.
"""
        logger.debug("Sending request to OpenAI")
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=f"Create MCQs from this note:\n\n{note_content}")
        ]
        llm = ChatOpenAI(
            model=model,  # previously hard-coded to "gpt-4o-mini"
            temperature=0,
            max_tokens=4000
        )
        response = llm.invoke(messages)
        logger.debug(f"Received response length: {len(response.content)} chars")
        # Fold this call's estimated usage into the shared counter.
        tokens_used = num_tokens_from_messages([
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": note_content},
            {"role": "assistant", "content": response.content}
        ])
        total_tokens[0] += tokens_used
        # The prompt mandates blank-line-separated MCQ blocks; split on
        # those and parse each independently, skipping malformed ones.
        mcqs = []
        mcq_texts = response.content.strip().split('\n\n')
        logger.info(f"Found {len(mcq_texts)} potential MCQ blocks")
        for mcq_text in mcq_texts:
            if mcq := parse_mcq(mcq_text):
                mcq.source_name = source_name
                mcqs.append(mcq)
        logger.info(f"Successfully parsed {len(mcqs)} MCQs")
        return mcqs
    except Exception as e:
        logger.exception(f"Error in MCQ generation: {e}")
        return []
| def parse_mcq(mcq_text: str) -> Optional[MCQ]: | |
| """Parse a single MCQ from text.""" | |
| try: | |
| lines = [line.strip() for line in mcq_text.split('\n') if line.strip()] | |
| if len(lines) < 7: | |
| return None | |
| # Extract question | |
| if not lines[0].startswith('Question:'): | |
| return None | |
| question = lines[0].replace('Question:', '', 1).strip() | |
| # Extract options | |
| options = [] | |
| for i, line in enumerate(lines[1:6], 1): | |
| if not line.startswith(chr(ord('A') + i - 1) + '.'): | |
| return None | |
| option = line.split('.', 1)[1].strip() | |
| options.append(option) | |
| # Extract answer | |
| correct_line = lines[6] | |
| if not correct_line.lower().startswith('correct answer:'): | |
| return None | |
| correct_letter = correct_line.split(':', 1)[1].strip().upper() | |
| if correct_letter not in 'ABCDE': | |
| return None | |
| correct_index = ord(correct_letter) - ord('A') | |
| correct_answer = options[correct_index] if correct_index < len(options) else options[-1] | |
| return MCQ( | |
| question=question, | |
| options=options, | |
| correct_answer=correct_answer | |
| ) | |
| except Exception as e: | |
| logger.error(f"Error parsing MCQ: {e}") | |
| return None | |
def present_mcqs_to_content(mcqs: List[MCQ], content: str, total_tokens: List[int]) -> List[Dict]:
    """Present MCQs to a document and collect the evaluator model's answers.

    The questions are sent in batches; the model is instructed to answer
    from `content` only and to choose 'E' when a question is unanswerable.

    Fix over the original: answers used to be mapped to questions purely by
    response-line *position*, so one stray or missing line shifted every
    subsequent answer, and questions past the last line silently got no
    response at all. Answers are now keyed by the question number the model
    reports, and every MCQ is guaranteed an entry (default "I don't know").

    Args:
        mcqs: Questions to present.
        content: Document text the questions must be answered from.
        total_tokens: Single-element list used as a mutable token counter.

    Returns:
        One dict per MCQ with question, user_answer, and correct_answer.
    """
    logger.info(f"Presenting {len(mcqs)} MCQs to content")
    user_responses = []
    batch_size = 20
    try:
        llm = ChatOpenAI(
            model="gpt-4o-mini",
            temperature=0
        )
        for i in range(0, len(mcqs), batch_size):
            batch_mcqs = mcqs[i:i + batch_size]
            questions_text = "\n\n".join([
                f"Question {j+1}: {mcq.question}\n"
                f"A. {mcq.options[0]}\n"
                f"B. {mcq.options[1]}\n"
                f"C. {mcq.options[2]}\n"
                f"D. {mcq.options[3]}\n"
                f"E. I don't know"
                for j, mcq in enumerate(batch_mcqs)
            ])
            prompt = f"""
You are an expert medical knowledge evaluator. Given a medical note and multiple questions:
1. For each question, verify if it can be answered from the given content
2. If a question cannot be answered from the content, choose 'E' (I don't know)
3. If a question can be answered, choose the most accurate option based ONLY on the given content
Document Content: {content}
{questions_text}
Respond with ONLY the question numbers and corresponding letters, one per line, like this:
1: A
2: B
etc.
"""
            messages = [HumanMessage(content=prompt)]
            response = llm.invoke(messages)
            tokens_used = num_tokens_from_messages([
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": response.content}
            ])
            total_tokens[0] += tokens_used
            # Parse "N: X" lines into {batch_index: letter}; ignore any line
            # that does not carry a question number and a valid letter.
            answers = {}
            for line in response.content.strip().split('\n'):
                head, sep, tail = line.partition(':')
                if not sep:
                    continue
                try:
                    question_number = int(head.strip())
                except ValueError:
                    continue
                letter = tail.strip().upper()
                if letter in ('A', 'B', 'C', 'D', 'E'):
                    answers[question_number - 1] = letter
            # Emit a response for every question in the batch; anything the
            # model did not answer cleanly counts as "I don't know".
            for j, mcq in enumerate(batch_mcqs):
                answer_letter = answers.get(j, 'E')
                if answer_letter == 'E':
                    user_answer_text = "I don't know"
                else:
                    option_index = ord(answer_letter) - ord('A')
                    if option_index < len(mcq.options):
                        user_answer_text = mcq.options[option_index]
                    else:
                        user_answer_text = "I don't know"
                user_responses.append({
                    "question": mcq.question,
                    "user_answer": user_answer_text,
                    "correct_answer": mcq.correct_answer
                })
    except Exception as e:
        logger.exception(f"Error in present_mcqs_to_content: {e}")
    return user_responses
# NOTE(review): the original source attached no route decorator, so Flask
# never exposed this view. '/' is the conventional index path — confirm
# against the links used in templates/index.html.
@app.route('/')
def index():
    """Serve the main page with the list of selectable models."""
    return render_template('index.html', models=MODELS)
# NOTE(review): the original source attached no route decorator, so this
# endpoint was unreachable. '/compare' (POST, since it reads request.form
# and request.files) is a guess — confirm the URL the frontend posts to.
@app.route('/compare', methods=['POST'])
def compare_documents():
    """Compare two uploaded .txt documents via cross-examination with MCQs.

    Generates MCQs from each document, asks each document's content to
    answer the *other* document's questions, and returns per-document
    scores plus the attempted/unknown answer breakdowns as JSON.

    Returns:
        (JSON payload, HTTP status): 200 on success, 400 on bad input,
        500 on unexpected failure.
    """
    logger.info("Starting document comparison")
    try:
        # Validate API key. NOTE(review): storing the per-request key in a
        # process-wide env var leaks it across concurrent requests —
        # consider passing it to ChatOpenAI directly instead.
        api_key = request.form.get('api_key')
        if not api_key:
            logger.error("No API key provided")
            return jsonify({"error": "OpenAI API key is required"}), 400
        os.environ['OPENAI_API_KEY'] = api_key
        logger.info("API key set")
        # Get model and document type.
        # TODO(review): `model` is read but never forwarded to the MCQ
        # generation calls below, which use their own default model.
        model = request.form.get('model', 'gpt-4o-mini')
        document_type = request.form.get('document_type', 'discharge_note')
        logger.info(f"Using model: {model}")
        # Validate files: both present, named, and .txt only.
        if 'doc1' not in request.files or 'doc2' not in request.files:
            return jsonify({"error": "Both documents are required"}), 400
        doc1_file = request.files['doc1']
        doc2_file = request.files['doc2']
        if not doc1_file.filename or not doc2_file.filename:
            return jsonify({"error": "No files selected"}), 400
        if not all(allowed_file(f.filename) for f in [doc1_file, doc2_file]):
            return jsonify({"error": "Only .txt files are allowed"}), 400
        # Read documents as UTF-8 text.
        try:
            doc1_text = doc1_file.read().decode('utf-8')
            doc2_text = doc2_file.read().decode('utf-8')
            logger.info(f"Files read successfully. Lengths: {len(doc1_text)}, {len(doc2_text)}")
        except UnicodeDecodeError:
            return jsonify({"error": "Invalid file encoding"}), 400
        # Shared mutable token counter threaded through all LLM calls.
        total_tokens = [0]
        doc1_mcqs = generate_mcqs_for_note(doc1_text, total_tokens, 'Doc1', document_type)
        doc2_mcqs = generate_mcqs_for_note(doc2_text, total_tokens, 'Doc2', document_type)
        logger.info(f"Generated MCQs - Doc1: {len(doc1_mcqs)}, Doc2: {len(doc2_mcqs)}")
        # Cross-examine: each document answers the OTHER document's MCQs.
        doc1_responses = present_mcqs_to_content(doc2_mcqs, doc1_text, total_tokens)
        doc2_responses = present_mcqs_to_content(doc1_mcqs, doc2_text, total_tokens)

        def process_mcq_results(responses, mcqs):
            # Split responses into attempted vs "I don't know" and compute
            # the score; the denominator includes unanswered questions.
            attempted = []
            unknown = []
            correct_count = 0
            total_count = len(responses)
            for i, response in enumerate(responses):
                if i >= len(mcqs):
                    continue
                mcq = mcqs[i]
                answer = response.get("user_answer", "I don't know")
                result = {
                    "question": mcq.question,
                    "options": mcq.options,
                    "ideal_answer": mcq.correct_answer,
                    "model_answer": answer
                }
                if answer == "I don't know":
                    unknown.append(result)
                else:
                    is_correct = answer == mcq.correct_answer
                    if is_correct:
                        correct_count += 1
                    result["is_correct"] = is_correct
                    attempted.append(result)
            return {
                "score": f"{correct_count}/{total_count}",
                "attempted_answers": attempted,
                "unknown_answers": unknown
            }

        doc1_analysis = process_mcq_results(doc1_responses, doc2_mcqs)
        doc2_analysis = process_mcq_results(doc2_responses, doc1_mcqs)
        # Prepare the JSON response consumed by the frontend.
        response = {
            "doc1_analysis": doc1_analysis,
            "doc2_analysis": doc2_analysis,
            "total_tokens": total_tokens[0],
            "doc1_content": doc1_text,
            "doc2_content": doc2_text
        }
        logger.info(f"Comparison complete. Total tokens: {total_tokens[0]}")
        return jsonify(response), 200
    except Exception as e:
        logger.exception("Error in compare_documents:")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # 7860 is the conventional Hugging Face Spaces port; the PORT env var
    # overrides it. Bind to all interfaces so the container is reachable.
    server_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=server_port)