# Spaces: Sleeping  (page-header residue from the hosting site; not part of the program)
import base64
import io
import json
import logging
import os
import re
import time
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, Optional
from xml.sax.saxutils import escape

import fitz  # PyMuPDF for PDF processing
import pandas as pd
import requests
import streamlit as st
from dotenv import load_dotenv
from PIL import Image
from reportlab.lib import colors
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle
# Load environment variables. This has to run before Config is defined,
# because Config reads GEMINI_API_KEY from the environment at class-creation
# time.
load_dotenv()

# Configure logging: INFO and above, with timestamp, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Configuration and constants
class Config:
    """Static settings shared by the whole application."""

    # Endpoint that every Gemini request is POSTed to.
    GEMINI_URL: str = (
        "https://generativelanguage.googleapis.com/v1beta/models/"
        "gemini-2.0-flash-exp:generateContent"
    )
    # Read from the environment when the class is created; None if unset.
    GEMINI_API_KEY: Optional[str] = os.getenv("GEMINI_API_KEY")
    # Number of attempts made per API call.
    MAX_RETRIES: int = 3
    # Passed as the `timeout` argument of each HTTP request.
    TIMEOUT: int = 30
    # Images larger than this (width, height) are scaled down before upload.
    MAX_IMAGE_SIZE: tuple = (1600, 1600)
    # MIME types accepted by file validation.
    ALLOWED_MIME_TYPES: list = ["image/jpeg", "image/png", "application/pdf"]
    # Upload size limit in bytes (5MB).
    MAX_FILE_SIZE: int = 5 * 1024 * 1024
# Custom exceptions
class APIError(Exception):
    """Raised when a call to the Gemini API fails."""


class ImageProcessingError(Exception):
    """Raised when an uploaded image cannot be preprocessed."""


class PDFProcessingError(Exception):
    """Raised when text cannot be extracted from an uploaded PDF."""
# Initialize session state
def init_session_state():
    """Create the session-state entries the app uses, leaving existing ones alone."""
    # Factories rather than values, so each list is a fresh object.
    default_factories = (
        ('processing_history', list),
        ('current_document', lambda: None),
        ('pdf_history', list),
    )
    for key, make_default in default_factories:
        if key not in st.session_state:
            st.session_state[key] = make_default()
# Page setup and styling
def setup_page():
    """Apply the Streamlit page configuration and inject the app's custom CSS."""
    page_options = {
        "page_title": "Medical Document Processor",
        "page_icon": "🏥",
        "layout": "wide",
        "initial_sidebar_state": "expanded",
    }
    st.set_page_config(**page_options)

    # Raw <style> block; it only takes effect because unsafe_allow_html is on.
    custom_css = """
    <style>
    .main {padding: 2rem; max-width: 1200px; margin: 0 auto;}
    .stCard {
    background-color: white;
    padding: 2rem;
    border-radius: 10px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
    margin: 1rem 0;
    }
    .header-container {
    background-color: #f8f9fa;
    padding: 2rem;
    border-radius: 10px;
    margin-bottom: 2rem;
    }
    .stButton>button {
    background-color: #007bff;
    color: white;
    border: none;
    padding: 0.5rem 1rem;
    border-radius: 5px;
    transition: all 0.3s ease;
    }
    .stButton>button:hover {
    background-color: #0056b3;
    transform: translateY(-2px);
    }
    .element-container {opacity: 1 !important;}
    .pdf-history-item {
    background-color: #f8f9fa;
    padding: 1rem;
    border-radius: 8px;
    margin: 0.5rem 0;
    border: 1px solid #dee2e6;
    }
    .metric-card {
    background-color: #f8f9fa;
    padding: 1rem;
    border-radius: 8px;
    border: 1px solid #dee2e6;
    margin: 0.5rem 0;
    }
    </style>
    """
    st.markdown(custom_css, unsafe_allow_html=True)
class PDFGenerator:
    """Render extracted medical data as a PDF report using ReportLab."""

    @staticmethod
    def _cell(value: Any) -> str:
        """Return *value* as table-cell text, using 'N/A' when it is missing."""
        if value is None or value == "":
            return "N/A"
        return str(value)

    @staticmethod
    def _grid_table(rows: list, col_widths: list) -> Table:
        """Build a table with a black grid and 6pt padding on every side."""
        table = Table(rows, colWidths=col_widths)
        # ReportLab documents cell padding as four per-side commands, so each
        # side is set explicitly.
        table.setStyle(TableStyle([
            ('GRID', (0, 0), (-1, -1), 1, colors.black),
            ('LEFTPADDING', (0, 0), (-1, -1), 6),
            ('RIGHTPADDING', (0, 0), (-1, -1), 6),
            ('TOPPADDING', (0, 0), (-1, -1), 6),
            ('BOTTOMPADDING', (0, 0), (-1, -1), 6),
        ]))
        return table

    @staticmethod
    def create_pdf(data: Dict[str, Any]) -> bytes:
        """Build the "Medical Document Report" PDF and return its bytes.

        Args:
            data: Structured data with optional ``patient_info``,
                ``symptoms``, ``vital_signs`` and ``medications`` entries.
                ``None`` or an empty dict yields a report that contains only
                the patient table filled with 'N/A'.

        Returns:
            The rendered PDF document as bytes.
        """
        # Callers pass None when no structured data could be extracted.
        data = data or {}
        cell = PDFGenerator._cell

        buffer = io.BytesIO()
        doc = SimpleDocTemplate(buffer, pagesize=letter)
        styles = getSampleStyleSheet()
        elements = []

        # Title
        title_style = ParagraphStyle(
            'CustomTitle',
            parent=styles['Heading1'],
            fontSize=24,
            spaceAfter=30,
        )
        elements.append(Paragraph("Medical Document Report", title_style))
        elements.append(Spacer(1, 20))

        # Patient information
        elements.append(Paragraph("Patient Information", styles['Heading2']))
        patient_info = data.get('patient_info')
        if not isinstance(patient_info, dict):
            patient_info = {}
        patient_rows = [
            ["Name:", cell(patient_info.get('name'))],
            ["Age:", cell(patient_info.get('age'))],
            ["Gender:", cell(patient_info.get('gender'))],
        ]
        elements.append(PDFGenerator._grid_table(patient_rows, [100, 300]))
        elements.append(Spacer(1, 20))

        # Symptoms
        symptoms = data.get('symptoms')
        if symptoms:
            if isinstance(symptoms, str):
                symptoms = [symptoms]
            elements.append(Paragraph("Symptoms", styles['Heading2']))
            # Paragraph text is XML-like markup: newlines are collapsed, so
            # lines are separated with <br/>, and the free text is escaped so
            # characters such as '<' or '&' cannot break the markup parser.
            symptoms_markup = "<br/>".join(
                f"- {escape(str(symptom))}" for symptom in symptoms
            )
            elements.append(Paragraph(symptoms_markup, styles['BodyText']))
            elements.append(Spacer(1, 20))

        # Vital signs
        vital_signs = [vs for vs in (data.get('vital_signs') or []) if isinstance(vs, dict)]
        if vital_signs:
            elements.append(Paragraph("Vital Signs", styles['Heading2']))
            vital_rows = [["Type", "Value"]] + [
                [cell(vs.get('type')), cell(vs.get('value'))] for vs in vital_signs
            ]
            elements.append(PDFGenerator._grid_table(vital_rows, [200, 200]))
            elements.append(Spacer(1, 20))

        # Medications
        medications = [med for med in (data.get('medications') or []) if isinstance(med, dict)]
        if medications:
            elements.append(Paragraph("Medications", styles['Heading2']))
            med_rows = [["Name", "Dosage", "Instructions"]] + [
                [cell(med.get('name')), cell(med.get('dosage')), cell(med.get('instructions'))]
                for med in medications
            ]
            elements.append(PDFGenerator._grid_table(med_rows, [150, 100, 250]))
            elements.append(Spacer(1, 20))

        doc.build(elements)
        return buffer.getvalue()
class ImageProcessor:
    """Validation and normalisation helpers for uploaded files and images.

    Both helpers are static so they can be called on the class
    (``ImageProcessor.validate_file(...)``) or on an instance
    (``processor.preprocess_image(...)``).
    """

    @staticmethod
    def validate_file(uploaded_file) -> tuple[bool, str]:
        """Check an uploaded file against the size and MIME-type limits.

        Args:
            uploaded_file: Object exposing ``size`` (bytes) and ``type``
                (MIME type string).

        Returns:
            ``(True, message)`` when the file is acceptable, otherwise
            ``(False, reason)``. Never raises; unexpected errors are logged
            and reported as a failed validation.
        """
        try:
            if uploaded_file.size > Config.MAX_FILE_SIZE:
                return False, f"File size exceeds {Config.MAX_FILE_SIZE // (1024*1024)}MB limit"
            if uploaded_file.type not in Config.ALLOWED_MIME_TYPES:
                return False, "Unsupported file type. Please upload JPEG, PNG, or PDF."
            return True, "File validation successful"
        except Exception as e:
            logger.error(f"File validation error: {str(e)}")
            return False, f"File validation failed: {str(e)}"

    @staticmethod
    def preprocess_image(image: Image.Image) -> Image.Image:
        """Return an RGB image no larger than ``Config.MAX_IMAGE_SIZE``.

        The caller's image is never modified: conversion creates a new image,
        and resizing is done on a copy.

        Raises:
            ImageProcessingError: If conversion or resizing fails.
        """
        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            max_width, max_height = Config.MAX_IMAGE_SIZE
            if image.size[0] > max_width or image.size[1] > max_height:
                # thumbnail() resizes in place, so work on a copy.
                image = image.copy()
                image.thumbnail(Config.MAX_IMAGE_SIZE, Image.Resampling.LANCZOS)
            return image
        except Exception as e:
            logger.error(f"Image preprocessing error: {str(e)}")
            raise ImageProcessingError(f"Failed to preprocess image: {str(e)}") from e
class DocumentProcessor:
    """Turn an uploaded image or PDF into text, a category and structured data.

    Text understanding is delegated to ``GeminiAPI.call_api``. Helpers that do
    not use instance state are static methods, so they work both as
    ``self.helper(...)`` and ``DocumentProcessor.helper(...)``.
    """

    # Leading list marker ("- ", "* ", "• ", "1. ", "2) ") on a reply line.
    _BULLET_PREFIX = re.compile(r'^\s*(?:[-*•]|\d+[.)])\s+')

    def __init__(self):
        self.image_processor = ImageProcessor()

    def process_document(self, uploaded_file) -> Dict[str, Any]:
        """Extract text from *uploaded_file*, classify it and structure it.

        Args:
            uploaded_file: Uploaded file object exposing ``type`` (MIME type)
                and a readable binary stream.

        Returns:
            Dict with ``document_type``, ``extracted_text`` and
            ``structured_data`` (``None`` when no text could be extracted).

        Raises:
            ValueError: For MIME types other than image/* or application/pdf.
            Exception: Any processing error is logged and re-raised.
        """
        try:
            if uploaded_file.type.startswith("image/"):
                image = Image.open(uploaded_file)
                # Called on the class so it works whether or not the helper
                # is declared static.
                processed_image = ImageProcessor.preprocess_image(image)
                image_base64 = self.encode_image(processed_image)
                extracted_text = self.extract_text(image_base64)
            elif uploaded_file.type == "application/pdf":
                extracted_text = self.extract_text_from_pdf(uploaded_file)
            else:
                raise ValueError("Unsupported file type.")

            results = {
                "document_type": "Unknown",
                "extracted_text": extracted_text,
                "structured_data": None,
            }
            # With no text (e.g. a PDF without a text layer) there is nothing
            # for the model to classify or structure, so skip those calls.
            if extracted_text and extracted_text.strip():
                results["document_type"] = self.classify_document(extracted_text)
                results["structured_data"] = self.extract_structured_data(extracted_text)
            return results
        except Exception as e:
            logger.error(f"Document processing error: {str(e)}")
            raise

    @staticmethod
    def _response_text(response: Dict[str, Any]) -> str:
        """Return the stripped text of the first candidate in an API response.

        Raises:
            APIError: If the response lacks the expected
                candidates/content/parts/text structure.
        """
        try:
            return response["candidates"][0]["content"]["parts"][0]["text"].strip()
        except (KeyError, IndexError, TypeError, AttributeError) as exc:
            raise APIError("API response did not contain any text") from exc

    @staticmethod
    def encode_image(image: "Image.Image") -> str:
        """Return *image* as a base64 string of its JPEG (quality 95) encoding."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG", quality=95)
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    @staticmethod
    def extract_text_from_pdf(uploaded_file) -> str:
        """Return the concatenated text of every page of an uploaded PDF.

        Raises:
            PDFProcessingError: If the file cannot be read or parsed.
        """
        try:
            # Rewind first: the stream may already have been consumed.
            uploaded_file.seek(0)
            pdf_bytes = uploaded_file.read()
            # The context manager closes the document even if a page fails.
            with fitz.open(stream=pdf_bytes, filetype="pdf") as pdf_document:
                return "".join(page.get_text() for page in pdf_document)
        except Exception as e:
            logger.error(f"PDF processing error: {str(e)}")
            raise PDFProcessingError(f"Failed to process PDF: {str(e)}") from e

    def classify_document(self, text: str) -> str:
        """Ask the model for the document category and return its reply."""
        prompt = f"""
        Analyze this medical document and classify it into one of the following categories:
        - Lab Report
        - Patient Chart
        - Prescription
        - Imaging Report
        - Medical Certificate
        - Other (specify)
        Provide only the category name.
        Document Text:
        {text}
        """
        response = GeminiAPI.call_api(prompt)
        return self._response_text(response)

    def extract_text(self, image_base64: str) -> str:
        """Ask the model to transcribe the text visible in a base64 JPEG."""
        prompt = """
        Extract all visible text from this medical document.
        Include:
        - Headers and titles
        - Patient information
        - Medical data and values
        - Notes and annotations
        - Dates and timestamps
        Format the output in a clear, structured manner.
        """
        response = GeminiAPI.call_api(prompt, image_base64)
        return self._response_text(response)

    def extract_structured_data(self, text: str) -> Dict[str, Any]:
        """Convert free text into the structured record used by the UI and PDF.

        The model's JSON reply is post-processed: a missing gender is
        predicted from the name, each medication name is normalised, and the
        symptom list is replaced by a dedicated symptom extraction.

        Raises:
            APIError: If a model call fails or returns no text.
            ValueError: If the reply contains no parseable JSON object.
        """
        prompt = f"""
        Analyze this medical text and return a valid JSON object with the following structure:
        {{
            "patient_info": {{
                "name": "string",
                "age": "string",
                "gender": "string"
            }},
            "symptoms": ["string"],
            "visits": [
                {{
                    "date": "string",
                    "reason": "string",
                    "notes": "string"
                }}
            ],
            "vital_signs": [
                {{
                    "type": "string",
                    "value": "string"
                }}
            ],
            "medications": [
                {{
                    "name": "string",
                    "dosage": "string",
                    "instructions": "string"
                }}
            ]
        }}
        Text to analyze:
        {text}
        """
        response = GeminiAPI.call_api(prompt)
        structured_data = self.parse_json_response(response)

        # The model may omit patient_info or return null for it.
        patient_info = structured_data.get('patient_info')
        if not isinstance(patient_info, dict):
            patient_info = {}
        structured_data['patient_info'] = patient_info

        # Predict gender if not mentioned
        if not patient_info.get('gender'):
            patient_info['gender'] = self.predict_gender(patient_info.get('name', ''))

        # Correct medicine names (entries that are not objects are dropped)
        structured_data['medications'] = [
            self.correct_medicine_name(med)
            for med in (structured_data.get('medications') or [])
            if isinstance(med, dict)
        ]

        # Improve symptoms extraction
        structured_data['symptoms'] = self.extract_symptoms(text)
        return structured_data

    @staticmethod
    def predict_gender(name: str) -> str:
        """Predict gender based on the patient's name.

        Returns 'N/A' without calling the model when no name is available.
        """
        name = str(name or '').strip()
        if not name:
            return 'N/A'
        prompt = f"""
        Based on the name '{name}', predict the gender. Return only 'Male' or 'Female'.
        """
        response = GeminiAPI.call_api(prompt)
        return DocumentProcessor._response_text(response)

    @staticmethod
    def correct_medicine_name(medication: Dict[str, Any]) -> Dict[str, Any]:
        """Correct the medicine name using a standardized approach.

        The dict is updated in place and returned. Entries without a name are
        returned untouched, and an empty model reply keeps the original name.
        """
        original_name = str(medication.get('name') or '').strip()
        if not original_name:
            return medication
        prompt = f"""
        Correct the following medicine name to its standard form:
        {original_name}
        Return only the corrected name.
        """
        response = GeminiAPI.call_api(prompt)
        medication['name'] = DocumentProcessor._response_text(response) or original_name
        return medication

    @staticmethod
    def extract_symptoms(text: str) -> list[str]:
        """Extract symptoms from the text.

        Returns one entry per non-empty line of the model's reply, with any
        leading bullet or list number removed so callers can add their own.
        """
        prompt = f"""
        Extract all symptoms mentioned in the following medical text. Return only a list of symptoms:
        {text}
        """
        response = GeminiAPI.call_api(prompt)
        lines = DocumentProcessor._response_text(response).split("\n")
        cleaned = (DocumentProcessor._BULLET_PREFIX.sub('', line).strip() for line in lines)
        return [symptom for symptom in cleaned if symptom]

    @staticmethod
    def parse_json_response(response: Dict[str, Any]) -> Dict[str, Any]:
        """Return the JSON object embedded in a model response.

        Text around the object (prose, Markdown code fences) is ignored:
        everything from the first '{' to the last '}' is parsed.

        Raises:
            APIError: If the response carries no text.
            ValueError: If no JSON object is found or it cannot be decoded.
        """
        try:
            response_text = DocumentProcessor._response_text(response)
            json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
            if json_match:
                return json.loads(json_match.group())
            raise ValueError("No JSON object found in response")
        except Exception as e:
            logger.error(f"JSON parsing error: {str(e)}")
            raise
class EHRViewer:
    """Streamlit rendering of the structured medical record."""

    @staticmethod
    def display_ehr(data: Dict[str, Any]):
        """Render patient info, symptoms, vital signs and medications.

        Args:
            data: Structured record. Sections whose entry is missing or empty
                are skipped; missing patient fields are shown as 'N/A'.
        """

        def shown(value: Any) -> str:
            # Display text for a patient field; 'N/A' when it is missing.
            return 'N/A' if value is None or value == '' else str(value)

        st.markdown("## 📊 Electronic Health Record")

        # patient_info may be absent or null in model output.
        patient_info = data.get('patient_info')
        if not isinstance(patient_info, dict):
            patient_info = {}

        with st.container():
            st.markdown("### 👤 Patient Information")
            cols = st.columns(3)
            cols[0].metric("Name", shown(patient_info.get('name')))
            cols[1].metric("Age", shown(patient_info.get('age')))
            cols[2].metric("Gender", shown(patient_info.get('gender')))

        if data.get('symptoms'):
            st.markdown("### 🤒 Symptoms")
            symptoms_text = "\n".join([f"- {symptom}" for symptom in data['symptoms']])
            st.markdown(symptoms_text)

        if data.get('vital_signs'):
            st.markdown("### 🫀 Vital Signs")
            vital_signs_df = pd.DataFrame(data['vital_signs'])
            st.dataframe(vital_signs_df, use_container_width=True)

        if data.get('medications'):
            st.markdown("### 💊 Medications")
            med_df = pd.DataFrame(data['medications'])
            st.dataframe(med_df, use_container_width=True)
class GeminiAPI:
    """Thin HTTP client for the Gemini ``generateContent`` endpoint."""

    # HTTP statuses worth retrying: rate limiting and transient server errors.
    _RETRYABLE_STATUS = frozenset({429, 500, 502, 503, 504})

    @staticmethod
    def call_api(prompt: str, image_base64: Optional[str] = None) -> Dict[str, Any]:
        """POST *prompt* (and optionally an image) and return the decoded JSON.

        Args:
            prompt: Text sent as the first content part.
            image_base64: Optional base64-encoded JPEG, sent as inline data.

        Returns:
            The response body parsed from JSON.

        Raises:
            APIError: If no API key is configured, if the request fails with
                a non-retryable HTTP status, or if every attempt fails.
        """
        if not Config.GEMINI_API_KEY:
            raise APIError("API call failed: GEMINI_API_KEY is not configured")

        # The key travels in a header rather than the query string, so it
        # cannot end up in exception messages, which include the request URL
        # and are both logged and shown to the user.
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": Config.GEMINI_API_KEY,
        }
        parts = [{"text": prompt}]
        if image_base64:
            parts.append({
                "inline_data": {
                    "mime_type": "image/jpeg",
                    "data": image_base64,
                }
            })
        payload = {"contents": [{"parts": parts}]}

        # Always make at least one attempt, even if MAX_RETRIES is set to 0.
        attempts = max(1, Config.MAX_RETRIES)
        for attempt in range(attempts):
            try:
                response = requests.post(
                    Config.GEMINI_URL,
                    headers=headers,
                    json=payload,
                    timeout=Config.TIMEOUT,
                )
                response.raise_for_status()
                return response.json()
            except requests.exceptions.RequestException as e:
                # No response (timeout, connection error) or a transient
                # status is retried; other statuses (e.g. 400/401/403) will
                # not succeed on a second try.
                status = getattr(getattr(e, "response", None), "status_code", None)
                retryable = status is None or status in GeminiAPI._RETRYABLE_STATUS
                if not retryable or attempt == attempts - 1:
                    logger.error(f"API call failed after {attempt + 1} attempts: {str(e)}")
                    raise APIError(f"API call failed: {str(e)}") from e
                # Exponential back-off: 1s, 2s, 4s, ...
                time.sleep(2 ** attempt)
def _render_sidebar_history():
    """Show one download button per previously generated PDF in the sidebar."""
    with st.sidebar:
        st.header("📋 Processing History")
        if st.session_state.pdf_history:
            for idx, pdf_record in enumerate(st.session_state.pdf_history):
                with st.expander(f"Document {idx + 1}: {pdf_record['timestamp']}"):
                    st.download_button(
                        "📄 Download PDF",
                        pdf_record['data'],
                        file_name=pdf_record['filename'],
                        mime="application/pdf",
                        # The index keeps keys unique even if two reports
                        # share the same second-resolution timestamp.
                        key=f"sidebar_{idx}_{pdf_record['timestamp']}",
                    )
        else:
            st.info("No documents processed yet")


def _render_results(results: Dict[str, Any], pdf_bytes: bytes, pdf_filename: str, timestamp: str):
    """Show classification, raw text, the EHR view and the download buttons."""
    st.success("Document processed successfully!")
    st.markdown(f"**Document Type:** {results['document_type']}")
    with st.expander("View Extracted Text"):
        st.text_area(
            "Raw Text",
            results['extracted_text'],
            height=200,
        )

    # Display EHR View
    if results['structured_data']:
        EHRViewer.display_ehr(results['structured_data'])

    # Download options
    st.markdown("### 📥 Download Options")
    json_col, pdf_col = st.columns(2)
    with json_col:
        json_str = json.dumps(results['structured_data'], indent=2)
        st.download_button(
            "⬇️ Download JSON",
            json_str,
            file_name=f"medical_data_{timestamp}.json",
            mime="application/json",
        )
    with pdf_col:
        st.download_button(
            "📄 Download PDF Report",
            pdf_bytes,
            file_name=pdf_filename,
            mime="application/pdf",
        )


def _render_pdf_history():
    """List every generated PDF in the main area with a download button."""
    st.markdown("### 📚 PDF History")
    if st.session_state.pdf_history:
        for idx, pdf_record in enumerate(st.session_state.pdf_history):
            label_col, button_col = st.columns([3, 1])
            with label_col:
                st.write(f"Report from {pdf_record['timestamp']}")
            with button_col:
                st.download_button(
                    "📄 View PDF",
                    pdf_record['data'],
                    file_name=pdf_record['filename'],
                    mime="application/pdf",
                    key=f"history_{idx}_{pdf_record['timestamp']}",
                )
    else:
        st.info("No PDF history available")


def main():
    """Run the Streamlit app: upload, process, display and offer downloads.

    Errors raised while handling an uploaded file are shown with
    ``st.error`` and logged; they do not propagate.
    """
    init_session_state()
    setup_page()
    st.title("🏥 Advanced Medical Document Processor")
    st.markdown("Upload medical documents (images or PDFs) for automated processing and analysis.")

    # Sidebar
    _render_sidebar_history()

    # Main content
    uploaded_file = st.file_uploader(
        "Choose a medical document",
        type=['png', 'jpg', 'jpeg', 'pdf'],
        help="Upload a clear image or PDF of a medical document (max 5MB)",
    )
    if not uploaded_file:
        return

    try:
        # Validate file
        is_valid, message = ImageProcessor.validate_file(uploaded_file)
        if not is_valid:
            st.error(message)
            return

        # Display file
        is_image = uploaded_file.type.startswith("image/")
        results_area = None
        if is_image:
            image = Image.open(uploaded_file)
            preview_col, results_area = st.columns([1, 2])
            with preview_col:
                st.image(image, caption="Uploaded Document", use_column_width=True)
        elif uploaded_file.type == "application/pdf":
            st.info("PDF file uploaded. Processing...")

        # Process document
        # NOTE: results are rendered only in the run triggered by this
        # button; the generated PDFs remain available from the history lists.
        if st.button("🔍 Process Document"):
            with st.spinner("Processing document..."):
                processor = DocumentProcessor()
                results = processor.process_document(uploaded_file)

                # Generate PDF. structured_data is None when no text was
                # extracted, so fall back to an empty record.
                pdf_bytes = PDFGenerator.create_pdf(results['structured_data'] or {})
                timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                pdf_filename = f"medical_report_{timestamp}.pdf"

                # Store in session state
                st.session_state.current_document = {
                    'timestamp': timestamp,
                    'results': results,
                }
                st.session_state.processing_history.append(
                    st.session_state.current_document
                )
                st.session_state.pdf_history.append({
                    'timestamp': timestamp,
                    'filename': pdf_filename,
                    'data': pdf_bytes,
                })

            # Display results. Images show them next to the preview; PDFs get
            # a plain container, because the `st` module itself is not a
            # context manager.
            if results_area is None:
                results_area = st.container()
            with results_area:
                _render_results(results, pdf_bytes, pdf_filename, timestamp)

        # Display PDF History
        _render_pdf_history()
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        logger.exception("Error in main processing loop")
# Script entry point: last-resort handler for anything main() lets through.
if __name__ == "__main__":
    try:
        main()
    except Exception:
        # Full traceback goes to the log; the user sees a generic message.
        st.error("An unexpected error occurred. Please try again later.")
        logger.exception("Unhandled exception in main application")