from flask import Flask, render_template, request, jsonify import os import tempfile import pandas as pd from werkzeug.utils import secure_filename from datetime import datetime from typing import List, Dict, Any, Optional, Union from pydantic import BaseModel, Field from langchain_openai import ChatOpenAI from langchain_core.messages import HumanMessage, SystemMessage import tiktoken import json from dotenv import load_dotenv import logging import sys # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', stream=sys.stdout ) logger = logging.getLogger(__name__) # Load environment variables load_dotenv() # Define data models class MCQ(BaseModel): question: str options: List[str] correct_answer: str source_name: str = Field(default="Unknown") class Document(BaseModel): name: str = '' content: str mcqs: List[MCQ] = Field(default_factory=list) # Initialize Flask app = Flask(__name__) app.config['UPLOAD_FOLDER'] = tempfile.mkdtemp() app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size ALLOWED_EXTENSIONS = {'txt'} MODELS = ['gpt-4o', 'gpt-4o-mini'] # Load note criteria try: with open('note_criteria.json', 'r') as f: NOTE_CRITERIA = json.load(f)['note_types'] except Exception as e: logger.error(f"Error loading note_criteria.json: {e}") NOTE_CRITERIA = {} def allowed_file(filename): """Check if uploaded file has allowed extension.""" return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS def num_tokens_from_messages(messages, model="gpt-4o"): """Estimate token usage for messages.""" try: encoding = tiktoken.get_encoding("cl100k_base") num_tokens = 0 for message in messages: num_tokens += 4 for key, value in message.items(): num_tokens += len(encoding.encode(str(value))) num_tokens += 2 return num_tokens except Exception as e: logger.warning(f"Error counting tokens: {e}") return 0 def generate_mcqs_for_note(note_content: str, total_tokens: List[int], source_name: str = '', document_type: str = 'discharge_note') -> List[MCQ]: """Generate Multiple Choice Questions from medical notes.""" logger.info(f"Starting MCQ generation for {source_name}") try: # Get criteria criteria = NOTE_CRITERIA[document_type]['relevancy_criteria'] criteria_list = "\n".join(f"{i+1}. {criterion}" for i, criterion in enumerate(criteria)) system_prompt = f""" You are an expert in creating MCQs based on medical notes. Generate 20 MCQs that ONLY focus on these key areas: {criteria_list} Rules and Format: 1. Each question must relate to specific content from these areas 2. Skip areas not mentioned in the note 3. Each question must have exactly 5 options (A-D plus E="I don't know") 4. Provide only questions and answers, no explanations 5. Use this exact format: Question: [text] A. [option] B. [option] C. [option] D. [option] E. I don't know Correct Answer: [letter] IMPORTANT: You MUST generate exactly 20 questions. """ logger.debug("Sending request to OpenAI") messages = [ SystemMessage(content=system_prompt), HumanMessage(content=f"Create MCQs from this note:\n\n{note_content}") ] llm = ChatOpenAI( model="gpt-4o-mini", temperature=0, max_tokens=4000 ) response = llm.invoke(messages) logger.debug(f"Received response length: {len(response.content)} chars") # Count tokens tokens_used = num_tokens_from_messages([ {"role": "system", "content": system_prompt}, {"role": "user", "content": note_content}, {"role": "assistant", "content": response.content} ]) total_tokens[0] += tokens_used # Parse MCQs mcqs = [] mcq_texts = response.content.strip().split('\n\n') logger.info(f"Found {len(mcq_texts)} potential MCQ blocks") for mcq_text in mcq_texts: if mcq := parse_mcq(mcq_text): mcq.source_name = source_name mcqs.append(mcq) logger.info(f"Successfully parsed {len(mcqs)} MCQs") return mcqs except Exception as e: logger.exception(f"Error in MCQ generation: {e}") return [] def parse_mcq(mcq_text: str) -> Optional[MCQ]: """Parse a single MCQ from text.""" try: lines = [line.strip() for line in mcq_text.split('\n') if line.strip()] if len(lines) < 7: return None # Extract question if not lines[0].startswith('Question:'): return None question = lines[0].replace('Question:', '', 1).strip() # Extract options options = [] for i, line in enumerate(lines[1:6], 1): if not line.startswith(chr(ord('A') + i - 1) + '.'): return None option = line.split('.', 1)[1].strip() options.append(option) # Extract answer correct_line = lines[6] if not correct_line.lower().startswith('correct answer:'): return None correct_letter = correct_line.split(':', 1)[1].strip().upper() if correct_letter not in 'ABCDE': return None correct_index = ord(correct_letter) - ord('A') correct_answer = options[correct_index] if correct_index < len(options) else options[-1] return MCQ( question=question, options=options, correct_answer=correct_answer ) except Exception as e: logger.error(f"Error parsing MCQ: {e}") return None def present_mcqs_to_content(mcqs: List[MCQ], content: str, total_tokens: List[int]) -> List[Dict]: """Present MCQs to content and collect responses.""" logger.info(f"Presenting {len(mcqs)} MCQs to content") user_responses = [] batch_size = 20 try: llm = ChatOpenAI( model="gpt-4o-mini", temperature=0 ) for i in range(0, len(mcqs), batch_size): batch_mcqs = mcqs[i:i + batch_size] questions_text = "\n\n".join([ f"Question {j+1}: {mcq.question}\n" f"A. {mcq.options[0]}\n" f"B. {mcq.options[1]}\n" f"C. {mcq.options[2]}\n" f"D. {mcq.options[3]}\n" f"E. I don't know" for j, mcq in enumerate(batch_mcqs) ]) prompt = f""" You are an expert medical knowledge evaluator. Given a medical note and multiple questions: 1. For each question, verify if it can be answered from the given content 2. If a question cannot be answered from the content, choose 'E' (I don't know) 3. If a question can be answered, choose the most accurate option based ONLY on the given content Document Content: {content} {questions_text} Respond with ONLY the question numbers and corresponding letters, one per line, like this: 1: A 2: B etc. """ messages = [HumanMessage(content=prompt)] response = llm.invoke(messages) tokens_used = num_tokens_from_messages([ {"role": "user", "content": prompt}, {"role": "assistant", "content": response.content} ]) total_tokens[0] += tokens_used response_lines = response.content.strip().split('\n') for j, line in enumerate(response_lines): if j >= len(batch_mcqs): break mcq = batch_mcqs[j] try: answer_letter = line.split(':')[1].strip().upper() if answer_letter not in ['A', 'B', 'C', 'D', 'E']: answer_letter = 'E' if answer_letter == 'E': user_answer_text = "I don't know" else: option_index = ord(answer_letter) - ord('A') user_answer_text = mcq.options[option_index] except (IndexError, ValueError): user_answer_text = "I don't know" user_responses.append({ "question": mcq.question, "user_answer": user_answer_text, "correct_answer": mcq.correct_answer }) except Exception as e: logger.exception(f"Error in present_mcqs_to_content: {e}") return user_responses @app.route('/') def index(): """Serve the main page.""" return render_template('index.html', models=MODELS) @app.route('/compare', methods=['POST']) def compare_documents(): """Compare two documents endpoint.""" logger.info("Starting document comparison") try: # Validate API key api_key = request.form.get('api_key') if not api_key: logger.error("No API key provided") return jsonify({"error": "OpenAI API key is required"}), 400 os.environ['OPENAI_API_KEY'] = api_key logger.info("API key set") # Get model and document type model = request.form.get('model', 'gpt-4o-mini') document_type = request.form.get('document_type', 'discharge_note') logger.info(f"Using model: {model}") # Validate files if 'doc1' not in request.files or 'doc2' not in request.files: return jsonify({"error": "Both documents are required"}), 400 doc1_file = request.files['doc1'] doc2_file = request.files['doc2'] if not doc1_file.filename or not doc2_file.filename: return jsonify({"error": "No files selected"}), 400 if not all(allowed_file(f.filename) for f in [doc1_file, doc2_file]): return jsonify({"error": "Only .txt files are allowed"}), 400 # Read documents try: doc1_text = doc1_file.read().decode('utf-8') doc2_text = doc2_file.read().decode('utf-8') logger.info(f"Files read successfully. Lengths: {len(doc1_text)}, {len(doc2_text)}") except UnicodeDecodeError: return jsonify({"error": "Invalid file encoding"}), 400 # Initialize token counter and generate MCQs total_tokens = [0] doc1_mcqs = generate_mcqs_for_note(doc1_text, total_tokens, 'Doc1', document_type) doc2_mcqs = generate_mcqs_for_note(doc2_text, total_tokens, 'Doc2', document_type) logger.info(f"Generated MCQs - Doc1: {len(doc1_mcqs)}, Doc2: {len(doc2_mcqs)}") # Get responses doc1_responses = present_mcqs_to_content(doc2_mcqs, doc1_text, total_tokens) doc2_responses = present_mcqs_to_content(doc1_mcqs, doc2_text, total_tokens) # Process results def process_mcq_results(responses, mcqs): attempted = [] unknown = [] correct_count = 0 total_count = len(responses) for i, response in enumerate(responses): if i >= len(mcqs): continue mcq = mcqs[i] answer = response.get("user_answer", "I don't know") result = { "question": mcq.question, "options": mcq.options, "ideal_answer": mcq.correct_answer, "model_answer": answer } if answer == "I don't know": unknown.append(result) else: is_correct = answer == mcq.correct_answer if is_correct: correct_count += 1 result["is_correct"] = is_correct attempted.append(result) return { "score": f"{correct_count}/{total_count}", "attempted_answers": attempted, "unknown_answers": unknown } doc1_analysis = process_mcq_results(doc1_responses, doc2_mcqs) doc2_analysis = process_mcq_results(doc2_responses, doc1_mcqs) # Prepare response response = { "doc1_analysis": doc1_analysis, "doc2_analysis": doc2_analysis, "total_tokens": total_tokens[0], "doc1_content": doc1_text, "doc2_content": doc2_text } logger.info(f"Comparison complete. Total tokens: {total_tokens[0]}") return jsonify(response), 200 except Exception as e: logger.exception("Error in compare_documents:") return jsonify({"error": str(e)}), 500 if __name__ == "__main__": port = int(os.environ.get("PORT", 7860)) app.run(host="0.0.0.0", port=port)