Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import sqlite3 | |
| import os | |
| import pytesseract | |
| from PIL import Image | |
| from pdf2image import convert_from_path | |
| from groq import Groq | |
| import json | |
| import logging | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # --- Configuration --- | |
| DATABASE = "medidoc.db" | |
| UPLOAD_FOLDER = "uploads" | |
| os.makedirs(UPLOAD_FOLDER, exist_ok=True) | |
| # --- Groq Client Initialization --- | |
| # Use environment variable for API key | |
| GROQ_API_KEY = os.getenv("GROQ_API_KEY", "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") | |
| client = Groq(api_key=GROQ_API_KEY) | |
| # --- Database Setup --- | |
| def init_db(): | |
| try: | |
| conn = sqlite3.connect(DATABASE) | |
| cursor = conn.cursor() | |
| cursor.execute(""" | |
| CREATE TABLE IF NOT EXISTS documents ( | |
| id INTEGER PRIMARY KEY AUTOINCREMENT, | |
| filename TEXT NOT NULL, | |
| category TEXT, | |
| document_date TEXT, | |
| doctor_name TEXT, | |
| hospital_name TEXT, | |
| summary TEXT, | |
| content TEXT | |
| ) | |
| """) | |
| conn.commit() | |
| conn.close() | |
| logger.info("Database initialized successfully") | |
| except Exception as e: | |
| logger.error(f"Database initialization failed: {e}") | |
| init_db() | |
| # --- FastAPI App --- | |
| app = FastAPI(title="MediDoc API", version="1.0.0") | |
| # Add CORS middleware | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], # In production, specify exact origins | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # --- Helper Functions --- | |
| def extract_text_from_file(filepath: str) -> str: | |
| """Extract text from PDF or image files""" | |
| try: | |
| if not os.path.exists(filepath): | |
| logger.error(f"File not found: {filepath}") | |
| return "" | |
| if filepath.lower().endswith(".pdf"): | |
| pages = convert_from_path(filepath) | |
| text = "" | |
| for page in pages: | |
| text += pytesseract.image_to_string(page) + "\n" | |
| return text.strip() | |
| else: | |
| # Handle image files | |
| with Image.open(filepath) as img: | |
| text = pytesseract.image_to_string(img) | |
| return text.strip() | |
| except Exception as e: | |
| logger.error(f"Error extracting text from {filepath}: {e}") | |
| return "" | |
| def process_with_llm(text: str) -> dict: | |
| """Analyze medical text using Groq's Llama model""" | |
| if not text.strip(): | |
| return { | |
| "category": "Empty Document", | |
| "document_date": "N/A", | |
| "doctor_name": "N/A", | |
| "hospital_name": "N/A", | |
| "summary": "Document appears to be empty or text could not be extracted.", | |
| } | |
| system_prompt = """ | |
| You are an expert medical data extraction assistant. Analyze the provided text from a medical document and extract key information. | |
| Respond ONLY with a valid JSON object containing exactly these keys: | |
| - "category": Choose from "Prescription", "Lab Report", "Medical Bill", "Pharmacy Bill", "Discharge Summary", "Consultation Notes", "Other" | |
| - "document_date": Date in YYYY-MM-DD format. If not found, use "N/A" | |
| - "doctor_name": Full name of the doctor. If not found, use "N/A" | |
| - "hospital_name": Name of hospital/clinic. If not found, use "N/A" | |
| - "summary": A brief, clear summary in 1-2 sentences describing what this document is about | |
| Return only the JSON object, no other text. | |
| """ | |
| fallback_response = { | |
| "category": "Other", | |
| "document_date": "N/A", | |
| "doctor_name": "N/A", | |
| "hospital_name": "N/A", | |
| "summary": "Medical document processed but specific information could not be extracted.", | |
| } | |
| try: | |
| completion = client.chat.completions.create( | |
| model="llama-3.1-8b-instant", | |
| messages=[ | |
| {"role": "system", "content": system_prompt}, | |
| {"role": "user", "content": f"Medical document text:\n\n{text[:2000]}"} # Limit text length | |
| ], | |
| temperature=0.1, | |
| max_tokens=300, | |
| top_p=1, | |
| stream=False, | |
| ) | |
| response_content = completion.choices[0].message.content.strip() | |
| # Clean up the response | |
| if response_content.startswith("```json"): | |
| response_content = response_content[7:] | |
| if response_content.endswith("```"): | |
| response_content = response_content[:-3] | |
| response_content = response_content.strip() | |
| parsed_response = json.loads(response_content) | |
| # Validate required keys | |
| required_keys = ["category", "document_date", "doctor_name", "hospital_name", "summary"] | |
| for key in required_keys: | |
| if key not in parsed_response: | |
| parsed_response[key] = "N/A" | |
| return parsed_response | |
| except json.JSONDecodeError as e: | |
| logger.error(f"JSON Parsing Error: {e}\nRaw Response: {response_content}") | |
| return fallback_response | |
| except Exception as e: | |
| logger.error(f"Error with Groq API: {e}") | |
| return fallback_response | |
| # --- API Endpoints --- | |
| async def root(): | |
| return {"message": "MediDoc API is running"} | |
| async def upload_document(file: UploadFile = File(...)): | |
| """Upload and process a medical document""" | |
| try: | |
| # Validate file type | |
| allowed_types = ['application/pdf', 'image/jpeg', 'image/jpg', 'image/png'] | |
| if file.content_type not in allowed_types: | |
| raise HTTPException(status_code=400, detail="Only PDF and image files are allowed") | |
| # Save uploaded file | |
| filepath = os.path.join(UPLOAD_FOLDER, file.filename) | |
| with open(filepath, "wb") as buffer: | |
| content = await file.read() | |
| if not content: | |
| raise HTTPException(status_code=400, detail="Uploaded file is empty") | |
| buffer.write(content) | |
| logger.info(f"File saved: {filepath}") | |
| # Extract text | |
| text = extract_text_from_file(filepath) | |
| if not text.strip(): | |
| # Clean up the file | |
| os.remove(filepath) | |
| raise HTTPException(status_code=400, detail="Could not extract text from the uploaded file") | |
| # Process with LLM | |
| processed_data = process_with_llm(text) | |
| # Save to database | |
| conn = sqlite3.connect(DATABASE) | |
| cursor = conn.cursor() | |
| cursor.execute( | |
| """INSERT INTO documents | |
| (filename, category, document_date, doctor_name, hospital_name, summary, content) | |
| VALUES (?, ?, ?, ?, ?, ?, ?)""", | |
| ( | |
| file.filename, | |
| processed_data.get("category", "N/A"), | |
| processed_data.get("document_date", "N/A"), | |
| processed_data.get("doctor_name", "N/A"), | |
| processed_data.get("hospital_name", "N/A"), | |
| processed_data.get("summary", "N/A"), | |
| text | |
| ), | |
| ) | |
| conn.commit() | |
| conn.close() | |
| logger.info(f"Document processed successfully: {file.filename}") | |
| return {"filename": file.filename, "info": processed_data, "status": "success"} | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Unexpected error processing file: {e}") | |
| raise HTTPException(status_code=500, detail="Internal server error occurred while processing the file") | |
| def get_documents(): | |
| """Retrieve all processed documents""" | |
| try: | |
| conn = sqlite3.connect(DATABASE) | |
| conn.row_factory = sqlite3.Row | |
| cursor = conn.cursor() | |
| cursor.execute(""" | |
| SELECT id, filename, category, document_date, doctor_name, hospital_name, summary | |
| FROM documents | |
| ORDER BY | |
| CASE WHEN document_date = 'N/A' THEN 1 ELSE 0 END, | |
| document_date DESC | |
| """) | |
| documents = [dict(row) for row in cursor.fetchall()] | |
| conn.close() | |
| return {"documents": documents, "count": len(documents)} | |
| except Exception as e: | |
| logger.error(f"Error retrieving documents: {e}") | |
| raise HTTPException(status_code=500, detail="Could not retrieve documents") | |
| class SearchResult(BaseModel): | |
| answer: str | |
| sources: list | |
| def search_medical_history(query: str): | |
| """Search through medical documents using natural language""" | |
| if not query.strip(): | |
| raise HTTPException(status_code=400, detail="Search query cannot be empty") | |
| try: | |
| conn = sqlite3.connect(DATABASE) | |
| cursor = conn.cursor() | |
| cursor.execute("SELECT filename, content, summary, category FROM documents") | |
| all_docs = cursor.fetchall() | |
| conn.close() | |
| if not all_docs: | |
| return {"answer": "No documents have been uploaded yet. Please upload some medical documents first.", "sources": []} | |
| # Prepare context for the AI | |
| context_parts = [] | |
| for i, doc in enumerate(all_docs): | |
| filename, content, summary, category = doc | |
| context_parts.append(f"Document {i+1}: {filename}\nCategory: {category}\nSummary: {summary}\nContent: {content[:1500]}") | |
| context = "\n\n---\n\n".join(context_parts) | |
| system_prompt = f""" | |
| You are a medical assistant helping a patient understand their medical history. | |
| Answer the user's question based ONLY on the provided medical documents. | |
| Guidelines: | |
| - Provide a clear, helpful answer | |
| - Mention specific document names when referencing information | |
| - If information is not available in the documents, say so clearly | |
| - Be concise but informative | |
| - Use medical terminology appropriately but explain complex terms | |
| Available Documents: | |
| {context} | |
| """ | |
| completion = client.chat.completions.create( | |
| model="llama-3.1-8b-instant", | |
| messages=[ | |
| {"role": "system", "content": system_prompt}, | |
| {"role": "user", "content": query} | |
| ], | |
| temperature=0.2, | |
| max_tokens=800, | |
| ) | |
| answer = completion.choices[0].message.content | |
| # Find relevant sources mentioned in the answer | |
| sources = [] | |
| for doc in all_docs: | |
| filename = doc[0] | |
| if filename.lower() in answer.lower(): | |
| sources.append({ | |
| "filename": filename, | |
| "summary": doc[2], | |
| "category": doc[3] | |
| }) | |
| return {"answer": answer, "sources": sources} | |
| except Exception as e: | |
| logger.error(f"Error during search: {e}") | |
| raise HTTPException(status_code=500, detail="Search service is currently unavailable") | |
| def health_check(): | |
| """Health check endpoint""" | |
| return {"status": "healthy", "database": "connected"} | |
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=8000) |