"""
Medical Research API Server - HuggingFace Compatible Version.

FastAPI application exposing REST and WebSocket endpoints in front of
``MedicalResearchEngine``. The same module is used on HuggingFace Spaces,
on Vercel, and for local development (see the ``__main__`` block at the end).
"""

import asyncio
import logging
import os
import sys
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Set

from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel

# ============================================================================
# DEPLOYMENT SETTINGS
# ============================================================================

ALLOWED_ORIGINS: List[str]

if os.getenv("SPACE_ID"):
    # We're on HuggingFace Spaces.
    print(f"🚀 Running on HuggingFace Space: {os.getenv('SPACE_ID')}")
    # Force port 7860 for HuggingFace; the local runner at the bottom of this
    # module reads PORT back instead of hard-coding its port.
    os.environ["PORT"] = "7860"
    # Restrict CORS to the known front-ends when deployed on a Space.
    ALLOWED_ORIGINS = [
        "https://medical-research-ai.vercel.app",
        "http://localhost:3000",
        "https://paulhemb-medsearchpro.hf.space",
    ]
else:
    ALLOWED_ORIGINS = ["*"]

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import engine - Vercel compatible
try:
    # Try the flat import first (Vercel runs from the api/ directory).
    from engine import MedicalResearchEngine
except ImportError:
    try:
        # Package-qualified import for local development.
        from api.engine import MedicalResearchEngine
    except ImportError:
        logger.warning("⚠️ MedicalResearchEngine not found - using fallback mode")

        class MedicalResearchEngine:  # type: ignore[no-redef]
            """Minimal stand-in used when the real engine cannot be imported."""

            def __init__(self):
                pass

            async def process_query_async(self, **kwargs):
                return {"answer": "Engine not available", "papers_used": 0}


# ============================================================================
# DOMAIN AND USER CONTEXT DEFINITIONS
# ============================================================================

MEDICAL_DOMAINS = [
    {"id": "internal_medicine", "name": "Internal Medicine", "icon": "🏥", "description": "General internal medicine and diagnosis"},
    {"id": "endocrinology", "name": "Endocrinology", "icon": "🧬", "description": "Hormonal and metabolic disorders"},
    {"id": "gastroenterology", "name": "Gastroenterology", "icon": "🩸", "description": "Digestive system disorders"},
    {"id": "pulmonology", "name": "Pulmonology", "icon": "🫁", "description": "Respiratory diseases and lung disorders"},
    {"id": "nephrology", "name": "Nephrology", "icon": "🧪", "description": "Kidney diseases and renal function"},
    {"id": "hematology", "name": "Hematology", "icon": "🩸", "description": "Blood disorders and hematologic diseases"},
    {"id": "infectious_disease", "name": "Infectious Diseases", "icon": "🦠", "description": "Infectious diseases and microbiology"},
    {"id": "obstetrics_gynecology", "name": "Obstetrics & Gynecology", "icon": "🤰", "description": "Women's health, pregnancy and reproductive medicine"},
    {"id": "pathology", "name": "Pathology", "icon": "🔬", "description": "Disease diagnosis through tissue examination"},
    {"id": "laboratory_medicine", "name": "Laboratory Medicine", "icon": "🧪", "description": "Clinical laboratory testing and biomarkers"},
    {"id": "bioinformatics", "name": "Bioinformatics", "icon": "💻", "description": "Computational analysis of biological data"},
    {"id": "clinical_research", "name": "Clinical Research", "icon": "📊", "description": "Clinical trials and evidence-based medicine"},
    {"id": "medical_imaging", "name": "Medical Imaging", "icon": "🩻", "description": "Medical imaging and radiology"},
    {"id": "oncology", "name": "Oncology", "icon": "🦠", "description": "Cancer research and treatment"},
    {"id": "cardiology", "name": "Cardiology", "icon": "❤️", "description": "Heart and cardiovascular diseases"},
    {"id": "neurology", "name": "Neurology", "icon": "🧠", "description": "Brain and nervous system disorders"},
    {"id": "pharmacology", "name": "Pharmacology", "icon": "💊", "description": "Drug therapy and medication management"},
    {"id": "genomics", "name": "Genomics", "icon": "🧬", "description": "Genetic research and personalized medicine"},
    # NOTE(review): this icon was reconstructed from mis-encoded source bytes
    # (the final byte was lost); it may originally have been 🌏 or 🌐 - confirm.
    {"id": "public_health", "name": "Public Health", "icon": "🌍", "description": "Population health and epidemiology"},
    {"id": "surgery", "name": "Surgery", "icon": "⚕️", "description": "Surgical procedures and techniques"},
    {"id": "pediatrics", "name": "Pediatrics", "icon": "👶", "description": "Child health and pediatric medicine"},
    {"id": "psychiatry", "name": "Psychiatry", "icon": "🧠", "description": "Mental health and psychiatric disorders"},
    {"id": "dermatology", "name": "Dermatology", "icon": "🦋", "description": "Skin diseases and dermatologic conditions"},
    {"id": "orthopedics", "name": "Orthopedics", "icon": "🦴", "description": "Musculoskeletal disorders and bone health"},
    {"id": "ophthalmology", "name": "Ophthalmology", "icon": "👁️", "description": "Eye diseases and vision care"},
    {"id": "urology", "name": "Urology", "icon": "💧", "description": "Urinary system and male reproductive health"},
    {"id": "emergency_medicine", "name": "Emergency Medicine", "icon": "🚑", "description": "Acute care and emergency response"},
    {"id": "critical_care", "name": "Critical Care", "icon": "🏥", "description": "Intensive care and critical illness"},
    {"id": "pain_medicine", "name": "Pain Medicine", "icon": "⚕️", "description": "Pain management and analgesia"},
    {"id": "nutrition", "name": "Nutrition", "icon": "🥗", "description": "Clinical nutrition and dietary management"},
    {"id": "allergy_immunology", "name": "Allergy & Immunology", "icon": "🤧", "description": "Allergic diseases and immune disorders"},
    {"id": "rehabilitation_medicine", "name": "Rehabilitation Medicine", "icon": "♿", "description": "Physical therapy and recovery"},
    {"id": "general_medical", "name": "General Medical", "icon": "⚕️", "description": "General medical research and clinical questions"},
    {"id": "auto", "name": "Auto-detect", "icon": "🤖", "description": "Automatically detect domain from query"},
]

USER_CONTEXTS = [
    {"id": "auto", "name": "Auto-detect", "icon": "🤖", "description": "Automatically detect user context"},
    {"id": "clinician", "name": "Clinician", "icon": "👨‍⚕️", "description": "Medical doctors, nurses, and healthcare providers"},
    {"id": "researcher", "name": "Researcher", "icon": "🔬", "description": "Academic researchers and scientists"},
    {"id": "student", "name": "Student", "icon": "🎓", "description": "Medical students and trainees"},
    {"id": "administrator", "name": "Administrator", "icon": "💼", "description": "Healthcare administrators and managers"},
    {"id": "patient", "name": "Patient", "icon": "👤", "description": "Patients and general public"},
    {"id": "general", "name": "General", "icon": "👤", "description": "General audience"},
]

VALID_DOMAINS: Set[str] = {domain["id"] for domain in MEDICAL_DOMAINS}
VALID_USER_CONTEXTS: Set[str] = {context["id"] for context in USER_CONTEXTS}

# Pseudo-domains that are not real specialties (used for the welcome message count).
NON_SPECIALTY_DOMAINS = ("auto", "general_medical")

DEFAULT_DOMAIN = "general_medical"
DEFAULT_USER_CONTEXT = "auto"

# Bounds applied to the number of papers a single query may request.
DEFAULT_MAX_PAPERS = 15
MIN_MAX_PAPERS = 1
MAX_MAX_PAPERS = 50

# WebSocket streaming: characters per chunk and pause between chunks.
STREAM_CHUNK_SIZE = 200
STREAM_CHUNK_DELAY_SECONDS = 0.05

# WebSocket message types that only make sense after "init_session".
# A tuple (not a set) so membership tests never need to hash client data.
SESSION_MESSAGE_TYPES = ("message", "update_context", "update_domain", "clear_history")


# ============================================================================
# PYDANTIC MODELS
# ============================================================================

class SessionCreate(BaseModel):
    """Schema for creating a new session."""

    session_id: Optional[str] = None
    user_context: str = DEFAULT_USER_CONTEXT


class ChatRequest(BaseModel):
    """Schema for a chat request."""

    message: str
    session_id: str
    domain: Optional[str] = DEFAULT_DOMAIN
    user_context: str = DEFAULT_USER_CONTEXT
    max_papers: int = DEFAULT_MAX_PAPERS


class ChatResponse(BaseModel):
    """Schema for a chat response."""

    success: bool
    message: str
    session_id: str
    processing_time: Optional[float] = None
    confidence_score: Optional[float] = None
    papers_used: Optional[int] = None
    user_context: Optional[str] = None
    raw_response: Optional[Dict] = None
    error: Optional[str] = None


# ============================================================================
# FASTAPI APP INITIALIZATION
# ============================================================================

app = FastAPI(
    title="Medical Research AI",
    description="Medical Research Assistant with Evidence-Based Analysis",
    version="1.0.0",
    docs_url="/api/docs",
    redoc_url="/api/redoc",
)

# CORS middleware - pass the computed origin list (not the string "ALLOWED_ORIGINS").
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files (only if the directory exists, relative to the CWD).
static_dir = "static"
if os.path.exists(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")
else:
    logger.warning(f"Static directory '{static_dir}' not found")

# Templates (only if the directory exists, relative to the CWD).
templates_dir = "templates"
templates: Optional[Jinja2Templates]
if os.path.exists(templates_dir):
    templates = Jinja2Templates(directory=templates_dir)
else:
    templates = None
    logger.warning(f"Templates directory '{templates_dir}' not found")

# Initialize chat engine (shared by every session and connection).
chat_engine = MedicalResearchEngine()

# Active WebSocket connections, keyed by session id.
active_connections: Dict[str, WebSocket] = {}

# In-memory session storage, keyed by session id.
# NOTE(review): entries are never expired, so this grows for the process lifetime.
user_sessions: Dict[str, Dict[str, Any]] = {}

# Strong references to fire-and-forget tasks; asyncio only keeps weak
# references, so an unreferenced task can be garbage-collected mid-run.
_background_tasks: Set[asyncio.Task] = set()


# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def validate_domain(domain: Any) -> str:
    """Return ``domain`` if it is a known domain id, else ``DEFAULT_DOMAIN``.

    Non-string values (e.g. ``None`` or a JSON list from a WebSocket client)
    are treated as invalid instead of raising ``TypeError``.
    """
    if not isinstance(domain, str) or domain not in VALID_DOMAINS:
        logger.warning("Invalid domain %r, defaulting to %r", domain, DEFAULT_DOMAIN)
        return DEFAULT_DOMAIN
    return domain


def validate_user_context(user_context: Any) -> str:
    """Return ``user_context`` if it is a known context id, else ``DEFAULT_USER_CONTEXT``."""
    if not isinstance(user_context, str) or user_context not in VALID_USER_CONTEXTS:
        logger.warning(
            "Invalid user_context %r, defaulting to %r", user_context, DEFAULT_USER_CONTEXT
        )
        return DEFAULT_USER_CONTEXT
    return user_context


def get_domain_by_id(domain_id: Any) -> Optional[Dict]:
    """Return the ``MEDICAL_DOMAINS`` entry whose id equals ``domain_id``, or ``None``."""
    return next((domain for domain in MEDICAL_DOMAINS if domain["id"] == domain_id), None)


def get_user_context_by_id(context_id: Any) -> Optional[Dict]:
    """Return the ``USER_CONTEXTS`` entry whose id equals ``context_id``, or ``None``."""
    return next((context for context in USER_CONTEXTS if context["id"] == context_id), None)


def split_into_chunks(text: str, chunk_size: int = STREAM_CHUNK_SIZE) -> List[str]:
    """Split ``text`` into consecutive slices of at most ``chunk_size`` characters.

    An empty string yields an empty list.
    """
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]


def _clamp_max_papers(value: Any) -> int:
    """Coerce ``value`` to an int within ``[MIN_MAX_PAPERS, MAX_MAX_PAPERS]``.

    Values that cannot be converted to ``int`` fall back to ``DEFAULT_MAX_PAPERS``.
    """
    try:
        requested = int(value)
    except (TypeError, ValueError):
        return DEFAULT_MAX_PAPERS
    return min(max(requested, MIN_MAX_PAPERS), MAX_MAX_PAPERS)


def _overall_confidence(response: Dict[str, Any]) -> Any:
    """Extract the overall confidence from an engine response.

    Accepts either ``{"confidence_score": {"overall_score": x}}`` or a bare
    numeric ``confidence_score``; anything else yields ``0``.
    """
    score = response.get("confidence_score")
    if isinstance(score, dict):
        return score.get("overall_score", 0)
    if isinstance(score, (int, float)):
        return score
    return 0


def _new_session(session_id: str, user_context: str) -> Dict[str, Any]:
    """Build a session record with every field the routes read.

    REST- and WebSocket-created sessions share this one shape.
    """
    now = datetime.now().isoformat()
    return {
        "id": session_id,
        "created_at": now,
        "user_context": user_context,
        "message_count": 0,
        "domains_used": set(),
        "last_domain": None,
        "query_types": [],
        "last_active": now,
    }


def _touch_session(session_id: str, domain: str) -> Optional[Dict[str, Any]]:
    """Record one incoming message for ``session_id``.

    Updates activity time, message count and domain usage.
    Returns the session dict, or ``None`` if the session is unknown.
    """
    session = user_sessions.get(session_id)
    if session is None:
        return None
    session["last_active"] = datetime.now().isoformat()
    session["message_count"] = session.get("message_count", 0) + 1
    session.setdefault("domains_used", set()).add(domain)
    # Tracked separately because a set has no notion of "most recent".
    session["last_domain"] = domain
    return session


def _record_query_type(session_id: str, response: Dict[str, Any]) -> None:
    """Append the engine-reported query type to the session history, if the session exists."""
    session = user_sessions.get(session_id)
    if session is not None:
        session.setdefault("query_types", []).append(response.get("query_type", "unknown"))


def _spawn_background(coro) -> asyncio.Task:
    """Schedule ``coro`` as a task and keep a strong reference until it finishes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


async def _safe_send_json(websocket: WebSocket, payload: Dict[str, Any]) -> bool:
    """Send ``payload`` over ``websocket``; return ``False`` instead of raising if the socket is unusable."""
    try:
        await websocket.send_json(payload)
        return True
    except Exception as exc:  # client already gone / socket closed
        logger.debug("Could not send WebSocket payload: %s", exc)
        return False


# Fallback landing page served when templates/index.html is unavailable.
# NOTE(review): the original markup of this page could not be recovered;
# this keeps its visible text. The GitHub link target is unknown.
_FALLBACK_HOME_HTML = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Medical Research AI</title>
</head>
<body>
<h1>🏥 Medical Research AI</h1>
<p>Evidence-based medical research assistant with AI-powered insights</p>
<ul>
<li><strong>{len(MEDICAL_DOMAINS)}</strong> Medical Domains</li>
<li><strong>{len(USER_CONTEXTS)}</strong> User Contexts</li>
<li><strong>API</strong> Ready</li>
</ul>
<h2>📚 API Documentation</h2>
<p><a href="/api/docs">📖 OpenAPI/Swagger Documentation</a> <a href="/api/redoc">📄 ReDoc Documentation</a></p>
<h2>🔧 API Endpoints</h2>
<p><a href="/api/health">❤️ Health Check</a> <a href="/api/v1/domains">🏥 Available Medical Domains</a> <a href="/api/v1/user_contexts">👤 User Contexts</a></p>
<h2>🚀 Quick Start</h2>
<p><a href="/api/docs">View API Docs</a> GitHub</p>
<footer>🚀 Deployed on Vercel | ⚡ FastAPI | 🧬 Medical AI</footer>
</body>
</html>
"""


# ============================================================================
# ROUTES
# ============================================================================

@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the chat interface, or a static landing page if templates are missing."""
    if templates:
        return templates.TemplateResponse("index.html", {"request": request})
    return HTMLResponse(content=_FALLBACK_HOME_HTML)


@app.get("/api/health")
async def health_check():
    """Report service health, engine configuration and in-memory counters."""
    return {
        "status": "healthy",
        "engine": "Medical Research Engine",
        "version": "1.0.0",
        "timestamp": datetime.now().isoformat(),
        "engine_configured": getattr(chat_engine, "api_configured", False),
        "features": [
            "Evidence-Based Medical Analysis",
            "Domain-Specific Research",
            "User Context Adaptation",
            "Paper Summarization",
        ],
        "stats": {
            "domains_count": len(MEDICAL_DOMAINS),
            "user_contexts_count": len(USER_CONTEXTS),
            "active_sessions": len(user_sessions),
            "active_connections": len(active_connections),
        },
    }


@app.get("/api/v1/domains")
async def get_domains():
    """Return every available medical domain."""
    return {
        "success": True,
        "domains": MEDICAL_DOMAINS,
        "count": len(MEDICAL_DOMAINS),
        "timestamp": datetime.now().isoformat(),
    }


@app.get("/api/v1/domains/{domain_id}")
async def get_domain_info(domain_id: str):
    """Return one domain, or a 404 JSON error if the id is unknown."""
    domain = get_domain_by_id(domain_id)
    if not domain:
        return JSONResponse(
            status_code=404,
            content={"error": f"Domain '{domain_id}' not found"},
        )
    return {
        "success": True,
        "domain": domain,
        "timestamp": datetime.now().isoformat(),
    }


@app.get("/api/v1/user_contexts")
async def get_user_contexts():
    """Return every available user context."""
    return {
        "success": True,
        "user_contexts": USER_CONTEXTS,
        "count": len(USER_CONTEXTS),
        "timestamp": datetime.now().isoformat(),
    }


@app.post("/api/v1/session/create")
async def create_session(request: Optional[SessionCreate] = None):
    """Create (or reset) a chat session and return a welcome message.

    The request body is optional; an omitted ``session_id`` gets a fresh UUID.
    An existing session with the same id is replaced.
    """
    if request is None:
        request = SessionCreate()

    session_id = request.session_id or str(uuid.uuid4())
    user_context = validate_user_context(request.user_context)
    user_sessions[session_id] = _new_session(session_id, user_context)

    # Initialize engine for this session
    if hasattr(chat_engine, "initialize_session"):
        chat_engine.initialize_session(session_id)

    context_info = get_user_context_by_id(user_context)
    context_name = context_info["name"] if context_info else user_context
    specialty_count = sum(1 for d in MEDICAL_DOMAINS if d["id"] not in NON_SPECIALTY_DOMAINS)

    welcome_lines = [
        "🎉 **Welcome to Medical Research Assistant!** 🧬",
        "",
        f"👤 **Your session context:** {context_name}",
        f"🏥 **Available Specialties:** {specialty_count} medical domains",
        "",
        "I can help you with:",
        "• **Medical Research Analysis** - Evidence-based insights",
        "• **Paper Summarization** - Key findings and implications",
        "• **Clinical Context** - Domain-specific applications",
        "• **Research Gap Identification** - Future directions",
        "",
        "**Try asking:**",
        '• "Latest treatments for diabetes" (Endocrinology)',
        "• \"Research on Alzheimer's biomarkers\" (Neurology)",
        '• "Clinical guidelines for hypertension" (Cardiology)',
        '• "Summarize recent advances in cancer immunotherapy" (Oncology)',
        "",
        "I'll adapt my responses based on your role and medical domain.",
    ]

    return {
        "session_id": session_id,
        "user_context": user_context,
        "context_info": context_info,
        "created_at": user_sessions[session_id]["created_at"],
        "welcome_message": "\n".join(welcome_lines),
    }


@app.post("/api/v1/chat")
async def chat_endpoint(request: ChatRequest):
    """Run one chat message through the engine and return a ``ChatResponse``.

    Any exception is reported as ``success=False`` rather than an HTTP error.
    For a known session, an explicit (non-"auto") ``user_context`` in the
    request wins and is stored on the session. Otherwise the session's stored
    context is used.
    """
    try:
        domain = validate_domain(request.domain)
        user_context = validate_user_context(request.user_context)
        max_papers = _clamp_max_papers(request.max_papers)

        session = _touch_session(request.session_id, domain)
        if session is not None:
            if user_context != DEFAULT_USER_CONTEXT:
                session["user_context"] = user_context
            else:
                user_context = session.get("user_context") or user_context

        logger.info("Processing chat - Domain: %s, Context: %s", domain, user_context)

        start_time = datetime.now()
        response = await chat_engine.process_query_async(
            query=request.message,
            domain=domain,
            session_id=request.session_id,
            user_context=user_context,
            max_papers=max_papers,
        )
        processing_time = (datetime.now() - start_time).total_seconds()

        _record_query_type(request.session_id, response)

        return ChatResponse(
            success=True,
            message=response.get("answer", "No response generated"),
            session_id=request.session_id,
            processing_time=processing_time,
            confidence_score=_overall_confidence(response),
            papers_used=response.get("papers_used", 0),
            user_context=response.get("user_context", user_context),
            raw_response=response,
        )
    except Exception as e:
        logger.error(f"Chat endpoint error: {str(e)}", exc_info=True)
        return ChatResponse(
            success=False,
            message=f"❌ Error: {str(e)}",
            session_id=request.session_id,
            error=str(e),
            user_context=request.user_context,
        )


@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket):
    """WebSocket endpoint for real-time chat.

    Clients send JSON objects with a ``type`` field. ``init_session`` must be
    sent before the session-bound types in ``SESSION_MESSAGE_TYPES``.
    Malformed frames get an ``error`` reply instead of dropping the connection.
    """
    await websocket.accept()
    session_id: Optional[str] = None
    user_context = DEFAULT_USER_CONTEXT

    try:
        while True:
            try:
                data = await websocket.receive_json()
            except (ValueError, KeyError, TypeError):
                # Bad JSON or a non-text frame: report it and keep listening.
                await websocket.send_json(
                    {"type": "error", "message": "Invalid message: expected JSON text"}
                )
                continue

            if not isinstance(data, dict):
                await websocket.send_json(
                    {"type": "error", "message": "Invalid message: expected a JSON object"}
                )
                continue

            message_type = data.get("type")

            if message_type == "init_session":
                new_session_id = str(data.get("session_id") or uuid.uuid4())
                user_context = validate_user_context(data.get("user_context", DEFAULT_USER_CONTEXT))

                # Re-initialising on the same socket: forget the old mapping.
                if (
                    session_id
                    and session_id != new_session_id
                    and active_connections.get(session_id) is websocket
                ):
                    del active_connections[session_id]
                session_id = new_session_id

                if session_id not in user_sessions:
                    user_sessions[session_id] = _new_session(session_id, user_context)
                    if hasattr(chat_engine, "initialize_session"):
                        chat_engine.initialize_session(session_id)

                active_connections[session_id] = websocket

                await websocket.send_json({
                    "type": "session_created",
                    "session_id": session_id,
                    "user_context": user_context,
                    "timestamp": datetime.now().isoformat(),
                    "features": [
                        "medical_research_analysis",
                        "domain_specific_insights",
                        "user_context_adaptation",
                    ],
                    "stats": {
                        "domains_available": len(MEDICAL_DOMAINS),
                        "user_contexts_available": len(USER_CONTEXTS),
                    },
                })

            elif message_type == "get_domains":
                await websocket.send_json({
                    "type": "domains_list",
                    "domains": MEDICAL_DOMAINS,
                    "count": len(MEDICAL_DOMAINS),
                })

            elif message_type == "get_contexts":
                await websocket.send_json({
                    "type": "contexts_list",
                    "user_contexts": USER_CONTEXTS,
                    "count": len(USER_CONTEXTS),
                })

            elif message_type in SESSION_MESSAGE_TYPES and session_id is None:
                # Previously these were silently dropped, leaving the client waiting.
                await websocket.send_json({
                    "type": "error",
                    "message": "Session not initialized: send 'init_session' first",
                })

            elif message_type == "message":
                user_message = data.get("message", "")
                domain = validate_domain(data.get("domain", DEFAULT_DOMAIN))
                user_context = validate_user_context(data.get("user_context", user_context))

                session = _touch_session(session_id, domain)
                if session is not None:
                    session["user_context"] = user_context

                await websocket.send_json({"type": "typing", "is_typing": True})

                # Process in background so this loop keeps receiving frames.
                _spawn_background(
                    process_websocket_message(
                        websocket, session_id, user_message, domain, user_context, data
                    )
                )

            elif message_type == "update_context":
                user_context = validate_user_context(data.get("user_context", DEFAULT_USER_CONTEXT))
                if session_id in user_sessions:
                    user_sessions[session_id]["user_context"] = user_context

                await websocket.send_json({
                    "type": "context_updated",
                    "user_context": user_context,
                    "context_info": get_user_context_by_id(user_context),
                    "session_id": session_id,
                })

            elif message_type == "update_domain":
                new_domain = validate_domain(data.get("domain", DEFAULT_DOMAIN))
                await websocket.send_json({
                    "type": "domain_updated",
                    "domain": new_domain,
                    "domain_info": get_domain_by_id(new_domain),
                    "session_id": session_id,
                })

            elif message_type == "clear_history":
                # NOTE(review): clear_memory() is called without a session id, so it
                # appears to clear the engine memory for all sessions - confirm.
                if hasattr(chat_engine, "clear_memory"):
                    chat_engine.clear_memory()
                await websocket.send_json({
                    "type": "history_cleared",
                    "session_id": session_id,
                })

            else:
                logger.debug("Ignoring unknown WebSocket message type: %r", message_type)

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {session_id}")
    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}", exc_info=True)
        await _safe_send_json(websocket, {
            "type": "error",
            "message": f"Connection error: {str(e)}",
        })
    finally:
        # Drop the mapping however the loop ends, but only if it still points
        # at this socket (delete_session or a newer connection may own it).
        if session_id and active_connections.get(session_id) is websocket:
            del active_connections[session_id]


async def process_websocket_message(websocket: WebSocket, session_id: str,
                                    user_message: str, domain: str,
                                    user_context: str, data: dict):
    """Run one WebSocket chat message through the engine and stream the answer back.

    Sends, in order: ``context_info``, one ``message_chunk`` per slice of the
    answer, then ``message_complete``. Failures are reported with an ``error``
    frame when the socket is still open.
    """
    try:
        response = await chat_engine.process_query_async(
            query=user_message,
            domain=domain,
            session_id=session_id,
            user_context=user_context,
            max_papers=_clamp_max_papers(data.get("max_papers", DEFAULT_MAX_PAPERS)),
        )
        _record_query_type(session_id, response)

        await websocket.send_json({
            "type": "context_info",
            "user_context": response.get("user_context", user_context),
            "domain": domain,
            "domain_info": get_domain_by_id(domain),
            "context_info": get_user_context_by_id(user_context),
        })

        # Send response in chunks (for streaming effect); a None answer streams nothing.
        answer = response.get("answer") or ""
        chunks = split_into_chunks(answer, STREAM_CHUNK_SIZE)
        total_chunks = len(chunks)
        for index, chunk in enumerate(chunks):
            await websocket.send_json({
                "type": "message_chunk",
                "chunk": chunk,
                "is_final": index == total_chunks - 1,
                "chunk_index": index,
                "total_chunks": total_chunks,
            })
            await asyncio.sleep(STREAM_CHUNK_DELAY_SECONDS)

        await websocket.send_json({
            "type": "message_complete",
            "message": answer,
            "metadata": {
                "confidence_score": _overall_confidence(response),
                "papers_used": response.get("papers_used", 0),
                "user_context": response.get("user_context", user_context),
                "domain": domain,
                "query_type": response.get("query_type", "general"),
            },
        })
    except Exception as e:
        logger.error(f"WebSocket message processing error: {str(e)}", exc_info=True)
        await _safe_send_json(websocket, {
            "type": "error",
            "message": f"Processing error: {str(e)}",
        })


@app.get("/api/v1/session/{session_id}")
async def get_session_info(session_id: str):
    """Return stored metadata for a session, or a 404 JSON error."""
    session = user_sessions.get(session_id)
    if session is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found"},
        )

    last_domain = session.get("last_domain")
    domain_info = get_domain_by_id(last_domain) if last_domain else None
    context_info = get_user_context_by_id(session.get("user_context", DEFAULT_USER_CONTEXT))

    return {
        "session_id": session_id,
        "created_at": session.get("created_at"),
        "user_context": session.get("user_context", DEFAULT_USER_CONTEXT),
        "context_info": context_info,
        "message_count": session.get("message_count", 0),
        "last_active": session.get("last_active"),
        # Sorted so the JSON output is deterministic (sets are unordered).
        "domains_used": sorted(session.get("domains_used", ())),
        "last_domain_info": domain_info,
        "query_types": session.get("query_types", []),
        "features_enabled": [
            "medical_research_analysis",
            "domain_specific_insights",
            "user_context_adaptation",
        ],
    }


@app.put("/api/v1/session/{session_id}/context")
async def update_session_context(session_id: str, request: dict):
    """Set a session's user context from ``request["user_context"]`` (validated)."""
    if session_id not in user_sessions:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found"},
        )

    new_context = validate_user_context(request.get("user_context", DEFAULT_USER_CONTEXT))
    user_sessions[session_id]["user_context"] = new_context

    return {
        "success": True,
        "session_id": session_id,
        "user_context": new_context,
        "context_info": get_user_context_by_id(new_context),
        "message": f"User context updated to {new_context}",
    }


@app.delete("/api/v1/session/{session_id}")
async def delete_session(session_id: str):
    """Delete a session and close its WebSocket, if any.

    Always reports success, including for unknown ids.
    """
    if session_id in user_sessions:
        # NOTE(review): clear_memory() takes no session id, so this appears to
        # clear the engine memory for every session - confirm against the engine.
        if hasattr(chat_engine, "clear_memory"):
            chat_engine.clear_memory()

        del user_sessions[session_id]

        connection = active_connections.pop(session_id, None)
        if connection is not None:
            try:
                await connection.close()
            except Exception as exc:  # best effort: the socket may already be closed
                logger.debug("Error closing WebSocket for %s: %s", session_id, exc)

    return {"success": True, "message": "Session deleted"}


@app.get("/api/v1/engine/status")
async def get_engine_status():
    """Return engine status/metrics if the engine exposes ``get_engine_status``."""
    if hasattr(chat_engine, "get_engine_status"):
        status = chat_engine.get_engine_status()
        return {
            "success": True,
            "engine": "Medical Research Engine",
            "domains_supported": len(MEDICAL_DOMAINS),
            "user_contexts_supported": len(USER_CONTEXTS),
            **status,
        }
    return {
        "success": False,
        "engine": "Unknown",
        "message": "Engine status not available",
    }


# ============================================================================
# DEVELOPMENT ONLY - Local server run
# ============================================================================

if __name__ == "__main__" and os.getenv("VERCEL") is None:
    # Only run locally, not on Vercel
    import uvicorn

    # Honour PORT (set to 7860 above on HuggingFace Spaces); default 8000 locally.
    port = int(os.getenv("PORT", "8000"))

    print(f"\n{'=' * 60}")
    print("🚀 STARTING MEDICAL RESEARCH AI SERVER (LOCAL)")
    print(f"{'=' * 60}")
    print(f"📚 API Docs: http://localhost:{port}/api/docs")
    print(f"🏥 Medical Domains: {len(MEDICAL_DOMAINS)}")
    print(f"👤 User Contexts: {len(USER_CONTEXTS)}")
    print(f"{'=' * 60}\n")

    # reload=True is not passed: uvicorn exits immediately when reload is
    # requested with an app object instead of an "module:app" import string.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
    )