Spaces:
Running
Running
| """ | |
| Medical Research API Server - HuggingFace Compatible Version | |
| """ | |
| # Add this for HuggingFace compatibility | |
| import os | |
| import sys | |
# Platform-specific settings.  A non-empty SPACE_ID environment variable is
# treated as "running on a HuggingFace Space".
if not os.environ.get('SPACE_ID'):
    # Anywhere else: accept requests from any origin.
    ALLOWED_ORIGINS = ["*"]
else:
    print(f"🚀 Running on HuggingFace Space: {os.environ.get('SPACE_ID')}")
    # Force port 7860 for HuggingFace
    os.environ['PORT'] = '7860'
    # Restrict CORS to the known front-ends.
    ALLOWED_ORIGINS = [
        "https://medical-research-ai.vercel.app",
        "http://localhost:3000",
        "https://paulhemb-medsearchpro.hf.space",
    ]
# Configure logging: the root logger is set to INFO, and `logger` is the
# module-level logger used by the helpers and routes in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Resolve MedicalResearchEngine: first as a sibling module (`engine`), then
# as `api.engine`; if neither import works, define a minimal stand-in so the
# API can still start.
try:
    from engine import MedicalResearchEngine
except ImportError:
    try:
        from api.engine import MedicalResearchEngine
    except ImportError:
        print("⚠️ MedicalResearchEngine not found - using fallback mode")

        class MedicalResearchEngine:
            """Fallback engine that answers every query with a fixed notice."""

            def __init__(self):
                pass

            async def process_query_async(self, **kwargs):
                """Ignore all keyword arguments and return the placeholder reply."""
                return dict(answer="Engine not available", papers_used=0)
| # ============================================================================ | |
| # DOMAIN AND USER CONTEXT DEFINITIONS | |
| # ============================================================================ | |
# Catalogue of selectable medical domains.  Each entry is a dict with the
# keys "id", "name", "icon" and "description" (in that order); the rows
# below are (id, name, icon, description).
MEDICAL_DOMAINS = [
    {"id": domain_id, "name": label, "icon": icon, "description": blurb}
    for domain_id, label, icon, blurb in (
        ("internal_medicine", "Internal Medicine", "🏥", "General internal medicine and diagnosis"),
        ("endocrinology", "Endocrinology", "🧬", "Hormonal and metabolic disorders"),
        ("gastroenterology", "Gastroenterology", "🩸", "Digestive system disorders"),
        ("pulmonology", "Pulmonology", "🫁", "Respiratory diseases and lung disorders"),
        ("nephrology", "Nephrology", "🧪", "Kidney diseases and renal function"),
        ("hematology", "Hematology", "🩸", "Blood disorders and hematologic diseases"),
        ("infectious_disease", "Infectious Diseases", "🦠", "Infectious diseases and microbiology"),
        ("obstetrics_gynecology", "Obstetrics & Gynecology", "🤰", "Women's health, pregnancy and reproductive medicine"),
        ("pathology", "Pathology", "🔬", "Disease diagnosis through tissue examination"),
        ("laboratory_medicine", "Laboratory Medicine", "🧪", "Clinical laboratory testing and biomarkers"),
        ("bioinformatics", "Bioinformatics", "💻", "Computational analysis of biological data"),
        ("clinical_research", "Clinical Research", "📊", "Clinical trials and evidence-based medicine"),
        ("medical_imaging", "Medical Imaging", "🩻", "Medical imaging and radiology"),
        ("oncology", "Oncology", "🦠", "Cancer research and treatment"),
        ("cardiology", "Cardiology", "❤️", "Heart and cardiovascular diseases"),
        ("neurology", "Neurology", "🧠", "Brain and nervous system disorders"),
        ("pharmacology", "Pharmacology", "💊", "Drug therapy and medication management"),
        ("genomics", "Genomics", "🧬", "Genetic research and personalized medicine"),
        ("public_health", "Public Health", "🌍", "Population health and epidemiology"),
        ("surgery", "Surgery", "⚕️", "Surgical procedures and techniques"),
        ("pediatrics", "Pediatrics", "👶", "Child health and pediatric medicine"),
        ("psychiatry", "Psychiatry", "🧠", "Mental health and psychiatric disorders"),
        ("dermatology", "Dermatology", "🦋", "Skin diseases and dermatologic conditions"),
        ("orthopedics", "Orthopedics", "🦴", "Musculoskeletal disorders and bone health"),
        ("ophthalmology", "Ophthalmology", "👁️", "Eye diseases and vision care"),
        ("urology", "Urology", "💧", "Urinary system and male reproductive health"),
        ("emergency_medicine", "Emergency Medicine", "🚑", "Acute care and emergency response"),
        ("critical_care", "Critical Care", "🏥", "Intensive care and critical illness"),
        ("pain_medicine", "Pain Medicine", "⚕️", "Pain management and analgesia"),
        ("nutrition", "Nutrition", "🥗", "Clinical nutrition and dietary management"),
        ("allergy_immunology", "Allergy & Immunology", "🤧", "Allergic diseases and immune disorders"),
        ("rehabilitation_medicine", "Rehabilitation Medicine", "♿", "Physical therapy and recovery"),
        ("general_medical", "General Medical", "⚕️", "General medical research and clinical questions"),
        ("auto", "Auto-detect", "🤖", "Automatically detect domain from query"),
    )
]
# Audience profiles a session can run under.  Same entry shape as the domain
# catalogue: "id", "name", "icon", "description"; rows are listed in that order.
USER_CONTEXTS = [
    {"id": context_id, "name": label, "icon": icon, "description": blurb}
    for context_id, label, icon, blurb in (
        ("auto", "Auto-detect", "🤖", "Automatically detect user context"),
        ("clinician", "Clinician", "👨⚕️", "Medical doctors, nurses, and healthcare providers"),
        ("researcher", "Researcher", "🔬", "Academic researchers and scientists"),
        ("student", "Student", "🎓", "Medical students and trainees"),
        ("administrator", "Administrator", "💼", "Healthcare administrators and managers"),
        ("patient", "Patient", "👤", "Patients and general public"),
        ("general", "General", "👤", "General audience"),
    )
]
# Sets of the "id" values from the two catalogues, used by validate_domain()
# and validate_user_context() for membership checks.
VALID_DOMAINS: Set[str] = {domain["id"] for domain in MEDICAL_DOMAINS}
VALID_USER_CONTEXTS: Set[str] = {context["id"] for context in USER_CONTEXTS}
| # ============================================================================ | |
| # PYDANTIC MODELS | |
| # ============================================================================ | |
class SessionCreate(BaseModel):
    """Request body for creating a new chat session."""
    # Caller-chosen session ID; create_session() generates a UUID4 string
    # when this is missing or empty.
    session_id: Optional[str] = None
    # Expected to be one of the USER_CONTEXTS ids; unknown values are
    # replaced with "auto" by validate_user_context().
    user_context: str = "auto"
class ChatRequest(BaseModel):
    """Request body for the chat endpoint."""
    # The user's question, passed to the engine as `query`.
    message: str
    # Session the message belongs to; activity is only tracked when this ID
    # exists in user_sessions.
    session_id: str
    # Medical domain ID; values not in VALID_DOMAINS fall back to
    # "general_medical" in validate_domain().
    domain: Optional[str] = "general_medical"
    # User context ID; values not in VALID_USER_CONTEXTS fall back to "auto".
    user_context: str = "auto"
    # Forwarded to the engine as `max_papers`; chat_endpoint() clamps it
    # to the range 1..50.
    max_papers: int = 15
class ChatResponse(BaseModel):
    """Response body returned by the chat endpoint."""
    # False when chat_endpoint() caught an exception.
    success: bool
    # The engine's "answer" text, or an error message on failure.
    message: str
    # Echo of the request's session ID.
    session_id: str
    # Seconds spent awaiting the engine call.
    processing_time: Optional[float] = None
    # Taken from response["confidence_score"]["overall_score"] in chat_endpoint().
    confidence_score: Optional[float] = None
    # The engine's "papers_used" value.
    papers_used: Optional[int] = None
    # Context reported by the engine, or the one resolved by the endpoint.
    user_context: Optional[str] = None
    # The complete dict returned by the engine.
    raw_response: Optional[Dict] = None
    # str() of the caught exception on failure.
    error: Optional[str] = None
| # ============================================================================ | |
| # FASTAPI APP INITIALIZATION | |
| # ============================================================================ | |
# The ASGI application.  Interactive docs are served under /api/docs and
# /api/redoc.
app = FastAPI(
    title="Medical Research AI",
    description="Medical Research Assistant with Evidence-Based Analysis",
    version="1.0.0",
    docs_url="/api/docs",
    redoc_url="/api/redoc"
)

# CORS middleware
# Pass the ALLOWED_ORIGINS list itself.  The previous value was the one-element
# list ["ALLOWED_ORIGINS"] (a string literal), which matches no real origin and
# so blocked every cross-origin browser request.
# NOTE(review): FastAPI's CORS docs advise against combining
# allow_credentials=True with "*" wildcards; the non-Spaces configuration uses
# ["*"] origins -- confirm whether credentials are needed there.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Static assets: mounted under /static only when the directory is present.
static_dir = "static"
if not os.path.exists(static_dir):
    logger.warning(f"Static directory '{static_dir}' not found")
else:
    app.mount("/static", StaticFiles(directory=static_dir), name="static")

# HTML templates: `templates` is None when the directory is absent, which
# makes the home route fall back to its built-in page.
templates_dir = "templates"
templates = Jinja2Templates(directory=templates_dir) if os.path.exists(templates_dir) else None
if templates is None:
    logger.warning(f"Templates directory '{templates_dir}' not found")

# Single engine instance shared by every route.
chat_engine = MedicalResearchEngine()

# session_id -> open WebSocket
active_connections: Dict[str, WebSocket] = {}

# session_id -> session record (in-memory only)
user_sessions: Dict[str, Dict] = {}
| # ============================================================================ | |
| # HELPER FUNCTIONS | |
| # ============================================================================ | |
def validate_domain(domain: str) -> str:
    """Return *domain* if it is a known domain ID, else ``"general_medical"``.

    Args:
        domain: Candidate domain ID.  Values arriving from JSON payloads may
            be ``None`` or a non-string type; those are treated as invalid.

    Returns:
        The unchanged ID when it is in ``VALID_DOMAINS``; otherwise the
        default ``"general_medical"`` (a warning is logged).
    """
    # Check the type first: a list or dict from a WebSocket payload is
    # unhashable and would raise TypeError in the set membership test.
    if isinstance(domain, str) and domain in VALID_DOMAINS:
        return domain
    logger.warning(f"Invalid domain '{domain}', defaulting to 'general_medical'")
    return "general_medical"
def validate_user_context(user_context: str) -> str:
    """Return *user_context* if it is a known context ID, else ``"auto"``.

    Args:
        user_context: Candidate context ID.  Values arriving from JSON
            payloads may be ``None`` or a non-string type; those are treated
            as invalid.

    Returns:
        The unchanged ID when it is in ``VALID_USER_CONTEXTS``; otherwise the
        default ``"auto"`` (a warning is logged).
    """
    # Check the type first: a list or dict from a request body is unhashable
    # and would raise TypeError in the set membership test.
    if isinstance(user_context, str) and user_context in VALID_USER_CONTEXTS:
        return user_context
    logger.warning(f"Invalid user_context '{user_context}', defaulting to 'auto'")
    return "auto"
def get_domain_by_id(domain_id: str) -> Optional[Dict]:
    """Look up a medical domain entry by its ID.

    Returns the first entry of ``MEDICAL_DOMAINS`` whose ``"id"`` equals
    *domain_id*, or ``None`` when nothing matches.
    """
    matching = (entry for entry in MEDICAL_DOMAINS if entry["id"] == domain_id)
    return next(matching, None)
def get_user_context_by_id(context_id: str) -> Optional[Dict]:
    """Look up a user context entry by its ID.

    Returns the first entry of ``USER_CONTEXTS`` whose ``"id"`` equals
    *context_id*, or ``None`` when nothing matches.
    """
    matching = (entry for entry in USER_CONTEXTS if entry["id"] == context_id)
    return next(matching, None)
def split_into_chunks(text: str, chunk_size: int = 200) -> List[str]:
    """Cut *text* into consecutive pieces of at most *chunk_size* characters.

    The pieces, concatenated in order, reproduce *text*; only the last one
    may be shorter.  An empty *text* yields an empty list.
    """
    pieces: List[str] = []
    for start in range(0, len(text), chunk_size):
        stop = start + chunk_size
        pieces.append(text[start:stop])
    return pieces
| # ============================================================================ | |
| # ROUTES | |
| # ============================================================================ | |
async def home(request: Request):
    """Serve the chat interface, or a built-in landing page as a fallback.

    When the templates directory was found at start-up, ``index.html`` is
    rendered.  Otherwise a static HTML page linking to the API docs and the
    main read-only endpoints is returned.
    """
    if templates:
        return templates.TemplateResponse("index.html", {"request": request})
    # Fallback HTML if templates not found.  The two statistics are filled in
    # from the live catalogues (placeholders replaced below) so the page
    # cannot drift from the data served by the API.  A plain string with
    # .replace() is used instead of an f-string because the CSS is full of
    # literal braces.
    html_content = """
    <!DOCTYPE html>
    <html>
    <head>
    <title>Medical Research AI</title>
    <style>
    body { font-family: Arial, sans-serif; margin: 0; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); min-height: 100vh; }
    .container { max-width: 800px; margin: 50px auto; background: white; padding: 40px; border-radius: 15px; box-shadow: 0 20px 60px rgba(0,0,0,0.3); }
    h1 { color: #333; margin-bottom: 10px; }
    .tagline { color: #666; font-size: 18px; margin-bottom: 30px; }
    .stats { display: flex; justify-content: space-between; margin: 30px 0; }
    .stat { text-align: center; flex: 1; padding: 20px; }
    .stat-number { font-size: 36px; font-weight: bold; color: #667eea; }
    .stat-label { color: #666; margin-top: 5px; }
    .api-link { display: block; margin: 15px 0; padding: 15px; background: #f8f9fa; border-radius: 8px; text-decoration: none; color: #333; border-left: 4px solid #667eea; transition: all 0.3s; }
    .api-link:hover { background: #e9ecef; transform: translateX(5px); }
    .btn { display: inline-block; padding: 12px 24px; background: #667eea; color: white; text-decoration: none; border-radius: 6px; margin: 10px 5px; }
    </style>
    </head>
    <body>
    <div class="container">
    <h1>🏥 Medical Research AI</h1>
    <p class="tagline">Evidence-based medical research assistant with AI-powered insights</p>
    <div class="stats">
    <div class="stat">
    <div class="stat-number">__DOMAIN_COUNT__</div>
    <div class="stat-label">Medical Domains</div>
    </div>
    <div class="stat">
    <div class="stat-number">__CONTEXT_COUNT__</div>
    <div class="stat-label">User Contexts</div>
    </div>
    <div class="stat">
    <div class="stat-number">API</div>
    <div class="stat-label">Ready</div>
    </div>
    </div>
    <h2>📚 API Documentation</h2>
    <a href="/api/docs" class="api-link">📖 OpenAPI/Swagger Documentation</a>
    <a href="/api/redoc" class="api-link">📄 ReDoc Documentation</a>
    <h2>🔧 API Endpoints</h2>
    <a href="/api/health" class="api-link">❤️ Health Check</a>
    <a href="/api/v1/domains" class="api-link">🏥 Available Medical Domains</a>
    <a href="/api/v1/user_contexts" class="api-link">👤 User Contexts</a>
    <h2>🚀 Quick Start</h2>
    <div style="margin: 20px 0;">
    <a href="/api/docs" class="btn">View API Docs</a>
    <a href="https://github.com/yourusername/medical-research-ai" class="btn" style="background: #333;">GitHub</a>
    </div>
    <div style="margin-top: 30px; padding-top: 20px; border-top: 1px solid #eee; color: #666; font-size: 14px;">
    <p>🚀 Deployed on Vercel | ⚡ FastAPI | 🧬 Medical AI</p>
    </div>
    </div>
    </body>
    </html>
    """
    html_content = (
        html_content
        .replace("__DOMAIN_COUNT__", str(len(MEDICAL_DOMAINS)))
        .replace("__CONTEXT_COUNT__", str(len(USER_CONTEXTS)))
    )
    return HTMLResponse(content=html_content)
async def health_check():
    """Report service health plus basic counters.

    Returns a dict with a fixed ``"status": "healthy"``, the current
    timestamp, whether the engine exposes a truthy ``api_configured``
    attribute, a static feature list, and counts of domains, user contexts,
    sessions and open WebSocket connections.
    """
    # The previous version also called chat_engine.get_engine_status() here
    # and discarded the result; that unused call has been removed so a health
    # probe no longer depends on it.
    return {
        "status": "healthy",
        "engine": "Medical Research Engine",
        "version": "1.0.0",
        "timestamp": datetime.now().isoformat(),
        # False when the engine (e.g. the fallback class) has no such attribute.
        "engine_configured": getattr(chat_engine, 'api_configured', False),
        "features": [
            "Evidence-Based Medical Analysis",
            "Domain-Specific Research",
            "User Context Adaptation",
            "Paper Summarization"
        ],
        "stats": {
            "domains_count": len(MEDICAL_DOMAINS),
            "user_contexts_count": len(USER_CONTEXTS),
            "active_sessions": len(user_sessions),
            "active_connections": len(active_connections)
        }
    }
async def get_domains():
    """Return the full medical-domain catalogue with its size and a timestamp."""
    payload = {"success": True, "domains": MEDICAL_DOMAINS}
    payload["count"] = len(MEDICAL_DOMAINS)
    payload["timestamp"] = datetime.now().isoformat()
    return payload
async def get_domain_info(domain_id: str):
    """Return the catalogue entry for *domain_id*, or a 404 JSON error."""
    domain = get_domain_by_id(domain_id)
    if domain:
        return {
            "success": True,
            "domain": domain,
            "timestamp": datetime.now().isoformat(),
        }
    # Unknown ID: answer with an explicit 404 body.
    not_found = {"error": f"Domain '{domain_id}' not found"}
    return JSONResponse(status_code=404, content=not_found)
async def get_user_contexts():
    """Return the full user-context catalogue with its size and a timestamp."""
    payload = {"success": True, "user_contexts": USER_CONTEXTS}
    payload["count"] = len(USER_CONTEXTS)
    payload["timestamp"] = datetime.now().isoformat()
    return payload
async def create_session(request: SessionCreate = None):
    """Register a new in-memory chat session and return its welcome payload.

    A missing body is treated as an empty ``SessionCreate``.  The session ID
    is the one supplied by the caller, or a fresh UUID4 string.  An existing
    record under the same ID is replaced.
    """
    payload = SessionCreate() if request is None else request
    session_id = payload.session_id or str(uuid.uuid4())
    user_context = validate_user_context(payload.user_context)

    record = {
        "id": session_id,
        "created_at": datetime.now().isoformat(),
        "user_context": user_context,
        "message_count": 0,
        "domains_used": set(),
        "last_active": datetime.now().isoformat(),
    }
    user_sessions[session_id] = record

    # Let the engine set up per-session state when it supports that.
    if hasattr(chat_engine, 'initialize_session'):
        chat_engine.initialize_session(session_id)

    context_info = get_user_context_by_id(user_context)
    context_label = context_info['name'] if context_info else user_context
    specialty_count = len(MEDICAL_DOMAINS) - 2
    welcome_lines = [
        "🎉 **Welcome to Medical Research Assistant!** 🧬",
        f"👤 **Your session context:** {context_label}",
        f"🏥 **Available Specialties:** {specialty_count} medical domains",
        "I can help you with:",
        "• **Medical Research Analysis** - Evidence-based insights",
        "• **Paper Summarization** - Key findings and implications",
        "• **Clinical Context** - Domain-specific applications",
        "• **Research Gap Identification** - Future directions",
        "**Try asking:**",
        "• \"Latest treatments for diabetes\" (Endocrinology)",
        "• \"Research on Alzheimer's biomarkers\" (Neurology)",
        "• \"Clinical guidelines for hypertension\" (Cardiology)",
        "• \"Summarize recent advances in cancer immunotherapy\" (Oncology)",
        "I'll adapt my responses based on your role and medical domain.",
    ]
    return {
        "session_id": session_id,
        "user_context": user_context,
        "context_info": context_info,
        "created_at": record["created_at"],
        "welcome_message": "\n".join(welcome_lines),
    }
async def chat_endpoint(request: ChatRequest):
    """Run one chat message through the engine and return a ``ChatResponse``.

    The domain and user context are validated (unknown values fall back to
    defaults) and ``max_papers`` is clamped to 1..50.  If the session ID is
    known, its activity counters are updated and its stored user context
    takes precedence over the one in the request.  Any exception is logged
    and reported as ``success=False`` rather than raised.
    """
    try:
        # Validate inputs
        domain = validate_domain(request.domain)
        user_context = validate_user_context(request.user_context)
        # Clamp to the supported range without mutating the request model.
        max_papers = min(max(request.max_papers, 1), 50)

        # Update session activity
        session = user_sessions.get(request.session_id)
        if session is not None:
            session["last_active"] = datetime.now().isoformat()
            session["message_count"] = session.get("message_count", 0) + 1
            # Sessions opened through the WebSocket route are created without
            # a "domains_used" entry; setdefault avoids the KeyError that
            # direct indexing raised for them.
            session.setdefault("domains_used", set()).add(domain)
            # Use session user_context if available
            if session.get("user_context"):
                user_context = session["user_context"]
            else:
                session["user_context"] = user_context

        logger.info(f"Processing chat - Domain: {domain}, Context: {user_context}")

        # Process the query
        start_time = datetime.now()
        response = await chat_engine.process_query_async(
            query=request.message,
            domain=domain,
            session_id=request.session_id,
            user_context=user_context,
            max_papers=max_papers
        )
        processing_time = (datetime.now() - start_time).total_seconds()

        # Track query type.  Look the session up again: it may have been
        # created or deleted while the engine call was awaited.
        session = user_sessions.get(request.session_id)
        if session is not None:
            session.setdefault("query_types", []).append(
                response.get("query_type", "unknown")
            )

        # The score is read from a dict's "overall_score"; tolerate an engine
        # that returns a bare value instead of crashing on .get().
        confidence = response.get("confidence_score", 0)
        if isinstance(confidence, dict):
            confidence = confidence.get("overall_score", 0)

        # Format response
        return ChatResponse(
            success=True,
            message=response.get("answer", "No response generated"),
            session_id=request.session_id,
            processing_time=processing_time,
            confidence_score=confidence,
            papers_used=response.get("papers_used", 0),
            user_context=response.get("user_context", user_context),
            raw_response=response
        )
    except Exception as e:
        # Boundary handler: report the failure in the response body.
        logger.error(f"Chat endpoint error: {str(e)}", exc_info=True)
        return ChatResponse(
            success=False,
            message=f"❌ Error: {str(e)}",
            session_id=request.session_id,
            error=str(e),
            user_context=request.user_context
        )
async def websocket_chat(websocket: WebSocket):
    """Serve one WebSocket chat connection until it closes.

    Each incoming JSON object is dispatched on its ``"type"``:
    ``init_session``, ``message``, ``update_context``, ``update_domain``,
    ``clear_history``, ``get_domains`` and ``get_contexts``.  Types that need
    a session are ignored until ``init_session`` has been received.  Chat
    messages are answered by a background task so the receive loop stays
    responsive.
    """
    await websocket.accept()
    session_id = None
    user_context = "auto"
    # Strong references to in-flight reply tasks: the event loop keeps only
    # weak references, so an unreferenced task may be garbage-collected
    # before it finishes.
    reply_tasks = set()
    try:
        while True:
            # Receive message
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                # Valid JSON that is not an object (list, string, number)
                # has no .get(); reject it instead of crashing the loop.
                await websocket.send_json({
                    "type": "error",
                    "message": "Invalid message: expected a JSON object"
                })
                continue
            message_type = data.get("type")

            if message_type == "init_session":
                # Create or get session
                previous_id = session_id
                session_id = data.get("session_id") or str(uuid.uuid4())
                user_context = validate_user_context(data.get("user_context", "auto"))
                if session_id not in user_sessions:
                    now = datetime.now().isoformat()
                    user_sessions[session_id] = {
                        "id": session_id,
                        "created_at": now,
                        "user_context": user_context,
                        "message_count": 0,
                        # Same keys as sessions made by the REST route, so
                        # both routes can share one record.
                        "domains_used": set(),
                        "last_active": now,
                        "websocket": websocket
                    }
                    if hasattr(chat_engine, 'initialize_session'):
                        chat_engine.initialize_session(session_id)
                # If this socket switches sessions, drop its old registration.
                if (previous_id and previous_id != session_id
                        and active_connections.get(previous_id) is websocket):
                    del active_connections[previous_id]
                active_connections[session_id] = websocket
                await websocket.send_json({
                    "type": "session_created",
                    "session_id": session_id,
                    "user_context": user_context,
                    "timestamp": datetime.now().isoformat(),
                    "features": [
                        "medical_research_analysis",
                        "domain_specific_insights",
                        "user_context_adaptation"
                    ],
                    "stats": {
                        "domains_available": len(MEDICAL_DOMAINS),
                        "user_contexts_available": len(USER_CONTEXTS)
                    }
                })

            elif message_type == "message" and session_id:
                # Process chat message
                user_message = data.get("message", "")
                domain = validate_domain(data.get("domain", "general_medical"))
                user_context = validate_user_context(data.get("user_context", user_context))
                # Update session context
                if session_id in user_sessions:
                    user_sessions[session_id]["user_context"] = user_context
                # Send typing indicator
                await websocket.send_json({
                    "type": "typing",
                    "is_typing": True
                })
                # Process in background
                task = asyncio.create_task(
                    process_websocket_message(
                        websocket, session_id, user_message,
                        domain, user_context, data
                    )
                )
                reply_tasks.add(task)
                task.add_done_callback(reply_tasks.discard)

            elif message_type == "update_context" and session_id:
                # Update user context
                new_context = validate_user_context(data.get("user_context", "auto"))
                user_context = new_context
                if session_id in user_sessions:
                    user_sessions[session_id]["user_context"] = new_context
                context_info = get_user_context_by_id(new_context)
                await websocket.send_json({
                    "type": "context_updated",
                    "user_context": user_context,
                    "context_info": context_info,
                    "session_id": session_id
                })

            elif message_type == "update_domain" and session_id:
                # Update domain (only echoed back; nothing is stored here)
                new_domain = validate_domain(data.get("domain", "general_medical"))
                domain_info = get_domain_by_id(new_domain)
                await websocket.send_json({
                    "type": "domain_updated",
                    "domain": new_domain,
                    "domain_info": domain_info,
                    "session_id": session_id
                })

            elif message_type == "clear_history" and session_id:
                # Clear chat history
                # NOTE(review): clear_memory() takes no session argument, so
                # it presumably clears memory for all sessions -- confirm.
                if hasattr(chat_engine, 'clear_memory'):
                    chat_engine.clear_memory()
                await websocket.send_json({
                    "type": "history_cleared",
                    "session_id": session_id
                })

            elif message_type == "get_domains":
                # Send domain list
                await websocket.send_json({
                    "type": "domains_list",
                    "domains": MEDICAL_DOMAINS,
                    "count": len(MEDICAL_DOMAINS)
                })

            elif message_type == "get_contexts":
                # Send user contexts list
                await websocket.send_json({
                    "type": "contexts_list",
                    "user_contexts": USER_CONTEXTS,
                    "count": len(USER_CONTEXTS)
                })
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {session_id}")
    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}")
        # The socket may already be unusable; do not let the notification
        # raise a second exception out of the handler.
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"Connection error: {str(e)}"
            })
        except Exception:
            logger.debug("Could not deliver error message to WebSocket client", exc_info=True)
    finally:
        # Always unregister, whatever ended the loop.  Only remove the entry
        # if it still points at this socket, so a newer connection for the
        # same session is left alone.
        if session_id and active_connections.get(session_id) is websocket:
            del active_connections[session_id]
async def process_websocket_message(websocket: WebSocket, session_id: str,
                                    user_message: str, domain: str,
                                    user_context: str, data: dict):
    """Answer one WebSocket chat message and stream the reply to the client.

    Sends, in order: a ``context_info`` message, the answer as
    ``message_chunk`` pieces of up to 200 characters, and a
    ``message_complete`` message carrying the full answer and metadata.
    Failures are logged and reported to the client as an ``error`` message
    when the socket is still usable.

    Args:
        websocket: Connection to reply on.
        session_id: Session the message belongs to.
        user_message: The user's query text.
        domain: Validated domain ID.
        user_context: Validated user context ID.
        data: The raw client payload; only ``"max_papers"`` is read.
    """
    try:
        # Apply the same 1..50 bounds as the REST chat endpoint; the value
        # comes straight from the client and may not even be a number.
        try:
            max_papers = int(data.get("max_papers", 15))
        except (TypeError, ValueError):
            max_papers = 15
        max_papers = min(max(max_papers, 1), 50)

        # Process query
        response = await chat_engine.process_query_async(
            query=user_message,
            domain=domain,
            session_id=session_id,
            user_context=user_context,
            max_papers=max_papers
        )

        # Send domain and context info
        domain_info = get_domain_by_id(domain)
        context_info = get_user_context_by_id(user_context)
        await websocket.send_json({
            "type": "context_info",
            "user_context": response.get("user_context", user_context),
            "domain": domain,
            "domain_info": domain_info,
            "context_info": context_info
        })

        # Send response in chunks (for streaming effect)
        answer = response.get("answer", "")
        chunks = split_into_chunks(answer, 200)
        total_chunks = len(chunks)
        for i, chunk in enumerate(chunks):
            await websocket.send_json({
                "type": "message_chunk",
                "chunk": chunk,
                "is_final": i == total_chunks - 1,
                "chunk_index": i,
                "total_chunks": total_chunks
            })
            await asyncio.sleep(0.05)  # Small delay for streaming effect

        # The score is read from a dict's "overall_score"; tolerate an engine
        # that returns a bare value instead of crashing on .get().
        confidence = response.get("confidence_score", 0)
        if isinstance(confidence, dict):
            confidence = confidence.get("overall_score", 0)

        # Send complete message with metadata
        await websocket.send_json({
            "type": "message_complete",
            "message": answer,
            "metadata": {
                "confidence_score": confidence,
                "papers_used": response.get("papers_used", 0),
                "user_context": response.get("user_context", user_context),
                "domain": domain,
                "query_type": response.get("query_type", "general")
            }
        })
    except WebSocketDisconnect:
        # Client left while the reply was in flight; nothing to report.
        logger.info(f"WebSocket closed before reply was delivered: {session_id}")
    except Exception as e:
        logger.error(f"WebSocket message processing error: {str(e)}", exc_info=True)
        # This runs as a background task: if the socket is already closed,
        # the notification itself fails, so guard it rather than leave an
        # unhandled task exception.
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"Processing error: {str(e)}"
            })
        except Exception:
            logger.debug("Could not deliver error message to WebSocket client", exc_info=True)
async def get_session_info(session_id: str):
    """Return a summary of one stored session, or a 404 JSON error."""
    try:
        session = user_sessions[session_id]
    except KeyError:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found"}
        )

    # Describe the "last" domain used, if any.
    # NOTE(review): sessions created by create_session() keep domains in a
    # set, so the element picked here is not necessarily the most recent one.
    domain_info = None
    domains_used = session.get("domains_used")
    if domains_used:
        last_domain = list(domains_used)[-1]
        if last_domain:
            domain_info = get_domain_by_id(last_domain)

    current_context = session.get("user_context", "auto")
    context_info = get_user_context_by_id(current_context)

    summary = {
        "session_id": session_id,
        "created_at": session.get("created_at"),
        "user_context": current_context,
        "context_info": context_info,
        "message_count": session.get("message_count", 0),
        "last_active": session.get("last_active"),
        # Converted to a list so the value is JSON-serialisable.
        "domains_used": list(session.get("domains_used", [])),
        "last_domain_info": domain_info,
        "query_types": session.get("query_types", []),
    }
    summary["features_enabled"] = [
        "medical_research_analysis",
        "domain_specific_insights",
        "user_context_adaptation"
    ]
    return summary
async def update_session_context(session_id: str, request: dict):
    """Change the stored user context of a session.

    Reads ``"user_context"`` from the request body (default ``"auto"``),
    validates it, stores it on the session and returns the new value with
    its catalogue entry.  Unknown sessions get a 404 JSON error.
    """
    if session_id in user_sessions:
        requested = request.get("user_context", "auto")
        new_context = validate_user_context(requested)
        user_sessions[session_id]["user_context"] = new_context
        return {
            "success": True,
            "session_id": session_id,
            "user_context": new_context,
            "context_info": get_user_context_by_id(new_context),
            "message": f"User context updated to {new_context}",
        }
    return JSONResponse(status_code=404, content={"error": "Session not found"})
async def delete_session(session_id: str):
    """Delete a session and close its WebSocket, if one is open.

    Always returns a success payload, including when the session does not
    exist.
    """
    if session_id in user_sessions:
        # Clear engine memory if method exists
        # NOTE(review): clear_memory() takes no session argument, so it
        # presumably clears memory for all sessions -- confirm.
        if hasattr(chat_engine, 'clear_memory'):
            chat_engine.clear_memory()
        # Remove from storage
        del user_sessions[session_id]
        # Close WebSocket if active.  pop() unregisters the connection before
        # the await: the socket's own handler also removes the entry once it
        # sees the close, and a later `del` could then raise KeyError.
        connection = active_connections.pop(session_id, None)
        if connection is not None:
            try:
                await connection.close()
            except Exception as exc:
                # Best effort: the socket may already be closed.
                logger.debug(f"Ignoring error while closing WebSocket for {session_id}: {exc}")
    return {"success": True, "message": "Session deleted"}
async def get_engine_status():
    """Return the engine's own status report merged with catalogue counts.

    When the engine has no ``get_engine_status`` method (e.g. the fallback
    class), a ``success=False`` payload is returned instead.
    """
    if not hasattr(chat_engine, 'get_engine_status'):
        return {
            "success": False,
            "engine": "Unknown",
            "message": "Engine status not available"
        }
    status = chat_engine.get_engine_status()
    report = {
        "success": True,
        "engine": "Medical Research Engine",
        "domains_supported": len(MEDICAL_DOMAINS),
        "user_contexts_supported": len(USER_CONTEXTS),
    }
    # Engine-provided keys win over the defaults above on a name clash.
    report.update(status)
    return report
| # ============================================================================ | |
| # DEVELOPMENT ONLY - Local server run | |
| # ============================================================================ | |
if __name__ == "__main__" and os.getenv("VERCEL") is None:
    # Only run when executed directly, and not on Vercel.
    import uvicorn

    # Honour the PORT environment variable (this file sets it to 7860 on
    # HuggingFace Spaces); default to 8000 for local development.
    port = int(os.getenv("PORT", "8000"))

    print(f"\n{'=' * 60}")
    print("🚀 STARTING MEDICAL RESEARCH AI SERVER (LOCAL)")
    print(f"{'=' * 60}")
    print(f"📚 API Docs: http://localhost:{port}/api/docs")
    print(f"🏥 Medical Domains: {len(MEDICAL_DOMAINS)}")
    print(f"👤 User Contexts: {len(USER_CONTEXTS)}")
    print(f"{'=' * 60}\n")

    # reload=True is not passed: uvicorn only supports auto-reload when the
    # app is given as an import string ("module:app") and refuses to start
    # when handed the app object with reload enabled.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port
    )