dvd_evaluator / app.py
from flask import Flask, render_template, request, jsonify
import os
import tempfile
from typing import List, Dict, Optional
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
import tiktoken
import json
from dotenv import load_dotenv
import logging
import sys
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
stream=sys.stdout
)
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()
# Define data models
class MCQ(BaseModel):
question: str
options: List[str]
correct_answer: str
source_name: str = Field(default="Unknown")
class Document(BaseModel):
name: str = ''
content: str
mcqs: List[MCQ] = Field(default_factory=list)
# Initialize Flask
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = tempfile.mkdtemp()
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
ALLOWED_EXTENSIONS = {'txt'}
MODELS = ['gpt-4o', 'gpt-4o-mini']
# Load note criteria
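# Illustrative shape of note_criteria.json (an assumption based on the lookups
# below, not the actual file):
# {
#   "note_types": {
#     "discharge_note": {"relevancy_criteria": ["Primary diagnosis", "Discharge medications", "..."]}
#   }
# }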
try:
with open('note_criteria.json', 'r') as f:
NOTE_CRITERIA = json.load(f)['note_types']
except Exception as e:
logger.error(f"Error loading note_criteria.json: {e}")
NOTE_CRITERIA = {}
def allowed_file(filename):
"""Check if uploaded file has allowed extension."""
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
def num_tokens_from_messages(messages, model="gpt-4o"):
    """Roughly estimate token usage for a list of chat messages."""
    try:
        try:
            # Prefer the tokenizer registered for the model; fall back to cl100k_base.
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            encoding = tiktoken.get_encoding("cl100k_base")
        num_tokens = 0
        for message in messages:
            num_tokens += 4  # rough per-message formatting overhead
            for value in message.values():
                num_tokens += len(encoding.encode(str(value)))
        num_tokens += 2  # rough allowance for the assistant reply priming
        return num_tokens
    except Exception as e:
        logger.warning(f"Error counting tokens: {e}")
        return 0
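# Worked example of the heuristic above (assuming the cl100k_base fallback): for
# [{"role": "user", "content": "hi"}] the estimate is
# 4 (per-message overhead) + tokens("user") + tokens("hi") + 2, roughly 8 tokens.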
def generate_mcqs_for_note(note_content: str, total_tokens: List[int], source_name: str = '', document_type: str = 'discharge_note', model: str = 'gpt-4o-mini') -> List[MCQ]:
"""Generate Multiple Choice Questions from medical notes."""
logger.info(f"Starting MCQ generation for {source_name}")
try:
        # Get criteria for the requested note type
        if document_type not in NOTE_CRITERIA:
            logger.error(f"Unknown document type: {document_type}")
            return []
        criteria = NOTE_CRITERIA[document_type]['relevancy_criteria']
criteria_list = "\n".join(f"{i+1}. {criterion}" for i, criterion in enumerate(criteria))
system_prompt = f"""
You are an expert in creating MCQs based on medical notes. Generate 20 MCQs that ONLY focus on these key areas:
{criteria_list}
Rules and Format:
1. Each question must relate to specific content from these areas
2. Skip areas not mentioned in the note
3. Each question must have exactly 5 options (A-D plus E="I don't know")
4. Provide only questions and answers, no explanations
5. Use this exact format:
Question: [text]
A. [option]
B. [option]
C. [option]
D. [option]
E. I don't know
Correct Answer: [letter]
IMPORTANT: You MUST generate exactly 20 questions.
"""
logger.debug("Sending request to OpenAI")
messages = [
SystemMessage(content=system_prompt),
HumanMessage(content=f"Create MCQs from this note:\n\n{note_content}")
]
        llm = ChatOpenAI(
            model=model,  # use the caller-selected model instead of a hard-coded one
            temperature=0,
            max_tokens=4000
        )
response = llm.invoke(messages)
logger.debug(f"Received response length: {len(response.content)} chars")
# Count tokens
        tokens_used = num_tokens_from_messages([
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": note_content},
            {"role": "assistant", "content": response.content}
        ], model=model)
        total_tokens[0] += tokens_used
# Parse MCQs
mcqs = []
mcq_texts = response.content.strip().split('\n\n')
logger.info(f"Found {len(mcq_texts)} potential MCQ blocks")
for mcq_text in mcq_texts:
if mcq := parse_mcq(mcq_text):
mcq.source_name = source_name
mcqs.append(mcq)
logger.info(f"Successfully parsed {len(mcqs)} MCQs")
return mcqs
except Exception as e:
logger.exception(f"Error in MCQ generation: {e}")
return []
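# parse_mcq() accepts blocks shaped like this (illustrative example):
#   Question: Which medication was started at discharge?
#   A. Aspirin
#   B. Metformin
#   C. Lisinopril
#   D. Warfarin
#   E. I don't know
#   Correct Answer: C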
def parse_mcq(mcq_text: str) -> Optional[MCQ]:
"""Parse a single MCQ from text."""
try:
        lines = [line.strip() for line in mcq_text.split('\n') if line.strip()]
        if len(lines) < 7:  # question + 5 options + answer line
            return None
# Extract question
if not lines[0].startswith('Question:'):
return None
question = lines[0].replace('Question:', '', 1).strip()
# Extract options
options = []
        for i, line in enumerate(lines[1:6]):
            expected_prefix = chr(ord('A') + i) + '.'
            if not line.startswith(expected_prefix):
                return None
            options.append(line.split('.', 1)[1].strip())
# Extract answer
correct_line = lines[6]
if not correct_line.lower().startswith('correct answer:'):
return None
        correct_letter = correct_line.split(':', 1)[1].strip().upper()
        # Set membership avoids the substring pitfall of `letter in 'ABCDE'`,
        # which is True for the empty string.
        if correct_letter not in {'A', 'B', 'C', 'D', 'E'}:
            return None
        correct_answer = options[ord(correct_letter) - ord('A')]
return MCQ(
question=question,
options=options,
correct_answer=correct_answer
)
except Exception as e:
logger.error(f"Error parsing MCQ: {e}")
return None
def present_mcqs_to_content(mcqs: List[MCQ], content: str, total_tokens: List[int], model: str = 'gpt-4o-mini') -> List[Dict]:
"""Present MCQs to content and collect responses."""
logger.info(f"Presenting {len(mcqs)} MCQs to content")
user_responses = []
batch_size = 20
try:
        llm = ChatOpenAI(
            model=model,  # use the caller-selected model
            temperature=0
        )
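        # Ask the questions in batches (20 per call here) so each prompt stays a
        # manageable size; each batch shares a single model call.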
for i in range(0, len(mcqs), batch_size):
batch_mcqs = mcqs[i:i + batch_size]
questions_text = "\n\n".join([
f"Question {j+1}: {mcq.question}\n"
f"A. {mcq.options[0]}\n"
f"B. {mcq.options[1]}\n"
f"C. {mcq.options[2]}\n"
f"D. {mcq.options[3]}\n"
f"E. I don't know"
for j, mcq in enumerate(batch_mcqs)
])
prompt = f"""
You are an expert medical knowledge evaluator. Given a medical note and multiple questions:
1. For each question, verify if it can be answered from the given content
2. If a question cannot be answered from the content, choose 'E' (I don't know)
3. If a question can be answered, choose the most accurate option based ONLY on the given content
Document Content: {content}
{questions_text}
Respond with ONLY the question numbers and corresponding letters, one per line, like this:
1: A
2: B
etc.
"""
messages = [HumanMessage(content=prompt)]
response = llm.invoke(messages)
            tokens_used = num_tokens_from_messages([
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": response.content}
            ], model=model)
            total_tokens[0] += tokens_used
            # Map answers by question number ("3: B") so a blank, preamble, or
            # skipped line does not shift later answers onto the wrong question.
            answers = {}
            for line in response.content.strip().split('\n'):
                if ':' not in line:
                    continue
                num_part, letter_part = line.split(':', 1)
                try:
                    q_num = int(num_part.strip().split()[-1])
                except (ValueError, IndexError):
                    continue
                answers[q_num] = letter_part.strip().upper()
            for j, mcq in enumerate(batch_mcqs):
                answer_letter = answers.get(j + 1, 'E')
                if answer_letter in {'A', 'B', 'C', 'D'}:
                    user_answer_text = mcq.options[ord(answer_letter) - ord('A')]
                else:
                    user_answer_text = "I don't know"
                user_responses.append({
                    "question": mcq.question,
                    "user_answer": user_answer_text,
                    "correct_answer": mcq.correct_answer
                })
except Exception as e:
logger.exception(f"Error in present_mcqs_to_content: {e}")
return user_responses
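# Each entry returned above has the shape (illustrative values):
#   {"question": "...", "user_answer": "Lisinopril", "correct_answer": "Lisinopril"}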
@app.route('/')
def index():
"""Serve the main page."""
return render_template('index.html', models=MODELS)
@app.route('/compare', methods=['POST'])
def compare_documents():
"""Compare two documents endpoint."""
logger.info("Starting document comparison")
try:
# Validate API key
api_key = request.form.get('api_key')
if not api_key:
logger.error("No API key provided")
return jsonify({"error": "OpenAI API key is required"}), 400
os.environ['OPENAI_API_KEY'] = api_key
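        # Note: os.environ is process-wide, so concurrent requests share whichever
        # API key was set most recently.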
logger.info("API key set")
# Get model and document type
model = request.form.get('model', 'gpt-4o-mini')
document_type = request.form.get('document_type', 'discharge_note')
logger.info(f"Using model: {model}")
# Validate files
if 'doc1' not in request.files or 'doc2' not in request.files:
return jsonify({"error": "Both documents are required"}), 400
doc1_file = request.files['doc1']
doc2_file = request.files['doc2']
if not doc1_file.filename or not doc2_file.filename:
return jsonify({"error": "No files selected"}), 400
if not all(allowed_file(f.filename) for f in [doc1_file, doc2_file]):
return jsonify({"error": "Only .txt files are allowed"}), 400
# Read documents
try:
doc1_text = doc1_file.read().decode('utf-8')
doc2_text = doc2_file.read().decode('utf-8')
logger.info(f"Files read successfully. Lengths: {len(doc1_text)}, {len(doc2_text)}")
except UnicodeDecodeError:
return jsonify({"error": "Invalid file encoding"}), 400
# Initialize token counter and generate MCQs
total_tokens = [0]
        doc1_mcqs = generate_mcqs_for_note(doc1_text, total_tokens, 'Doc1', document_type, model)
        doc2_mcqs = generate_mcqs_for_note(doc2_text, total_tokens, 'Doc2', document_type, model)
logger.info(f"Generated MCQs - Doc1: {len(doc1_mcqs)}, Doc2: {len(doc2_mcqs)}")
        # Cross-examine: each document answers the MCQs generated from the other,
        # so each score reflects how much of one note is recoverable from the other.
        doc1_responses = present_mcqs_to_content(doc2_mcqs, doc1_text, total_tokens, model)
        doc2_responses = present_mcqs_to_content(doc1_mcqs, doc2_text, total_tokens, model)
# Process results
def process_mcq_results(responses, mcqs):
attempted = []
unknown = []
correct_count = 0
total_count = len(responses)
for i, response in enumerate(responses):
if i >= len(mcqs):
continue
mcq = mcqs[i]
answer = response.get("user_answer", "I don't know")
result = {
"question": mcq.question,
"options": mcq.options,
"ideal_answer": mcq.correct_answer,
"model_answer": answer
}
if answer == "I don't know":
unknown.append(result)
else:
is_correct = answer == mcq.correct_answer
if is_correct:
correct_count += 1
result["is_correct"] = is_correct
attempted.append(result)
return {
"score": f"{correct_count}/{total_count}",
"attempted_answers": attempted,
"unknown_answers": unknown
}
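        # The score denominator counts every question, including those answered
        # "I don't know", so unknowns lower the score rather than being excluded.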
doc1_analysis = process_mcq_results(doc1_responses, doc2_mcqs)
doc2_analysis = process_mcq_results(doc2_responses, doc1_mcqs)
# Prepare response
response = {
"doc1_analysis": doc1_analysis,
"doc2_analysis": doc2_analysis,
"total_tokens": total_tokens[0],
"doc1_content": doc1_text,
"doc2_content": doc2_text
}
logger.info(f"Comparison complete. Total tokens: {total_tokens[0]}")
return jsonify(response), 200
except Exception as e:
logger.exception("Error in compare_documents:")
return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
port = int(os.environ.get("PORT", 7860))
app.run(host="0.0.0.0", port=port)
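# Example request (illustrative; adjust host, port, and file names):
#   curl -X POST http://localhost:7860/compare \
#     -F api_key=sk-... \
#     -F model=gpt-4o-mini \
#     -F document_type=discharge_note \
#     -F doc1=@note_a.txt \
#     -F doc2=@note_b.txt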