Spaces:
Running
Running
| """ | |
| api/main.py - Production-Ready Medical Research API Server for Vercel | |
| Updated to support role-based reasoning and domain-agnostic responses | |
| """ | |
| from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from typing import Dict, List, Optional, Set | |
| from pydantic import BaseModel | |
| import asyncio | |
| import json | |
| import os | |
| from datetime import datetime | |
| import uuid | |
| import logging | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import engine - Vercel compatible.
# The bare module name is tried first (the original comment notes Vercel runs
# from the api/ directory); the package-qualified name covers local runs from
# the repository root.
try:
    from engine import MedicalResearchEngine
except ImportError:
    try:
        from api.engine import MedicalResearchEngine
    except ImportError as engine_import_error:
        # Report through the module logger rather than print(), and include the
        # ImportError text: a missing dependency *inside* the engine module also
        # lands here and would otherwise be indistinguishable from "not found".
        logger.warning(
            "MedicalResearchEngine not found - using fallback mode (%s)",
            engine_import_error,
        )

        class MedicalResearchEngine:
            """Stand-in engine used when the real engine cannot be imported.

            Only ``process_query_async`` is provided; the API checks the other
            engine members (``get_engine_status``, ``initialize_session``,
            ``clear_memory``, ``api_configured``) with ``hasattr`` before use.
            """

            async def process_query_async(self, **kwargs):
                """Return a fixed placeholder answer, ignoring all arguments."""
                return {"answer": "Engine not available", "papers_used": 0}
# ============================================================================
# DOMAIN AND USER CONTEXT DEFINITIONS
# ============================================================================
# Medical domains offered to clients, intended to mirror the domain list in
# rag_engine.py. Each row is (id, display name, icon, description); the final
# "auto" row requests automatic domain detection.
_DOMAIN_ROWS = (
    ("internal_medicine", "Internal Medicine", "π₯",
     "General internal medicine and diagnosis"),
    ("endocrinology", "Endocrinology", "π§¬",
     "Hormonal and metabolic disorders"),
    ("cardiology", "Cardiology", "β€οΈ",
     "Heart and cardiovascular diseases"),
    ("neurology", "Neurology", "π§ ",
     "Brain and nervous system disorders"),
    ("oncology", "Oncology", "π¦ ",
     "Cancer research and treatment"),
    ("infectious_disease", "Infectious Diseases", "π¦ ",
     "Infectious diseases and microbiology"),
    ("clinical_research", "Clinical Research", "π",
     "Clinical trials and evidence-based medicine"),
    ("general_medical", "General Medical", "βοΈ",
     "General medical research and clinical questions"),
    ("pulmonology", "Pulmonology", "π«",
     "Respiratory diseases and lung health"),
    ("gastroenterology", "Gastroenterology", "π½οΈ",
     "Digestive system disorders"),
    ("nephrology", "Nephrology", "π«",
     "Kidney diseases and disorders"),
    ("hematology", "Hematology", "π©Έ",
     "Blood disorders and hematologic diseases"),
    ("surgery", "Surgery", "πͺ",
     "Surgical procedures and interventions"),
    ("orthopedics", "Orthopedics", "π¦΄",
     "Musculoskeletal disorders and injuries"),
    ("urology", "Urology", "π½",
     "Urinary tract and male reproductive system"),
    ("ophthalmology", "Ophthalmology", "ποΈ",
     "Eye diseases and vision disorders"),
    ("dermatology", "Dermatology", "π¦",
     "Skin diseases and disorders"),
    ("psychiatry", "Psychiatry", "π§",
     "Mental health and psychiatric disorders"),
    ("obstetrics_gynecology", "Obstetrics & Gynecology", "π€°",
     "Women's health and reproductive medicine"),
    ("pediatrics", "Pediatrics", "πΆ",
     "Child health and pediatric medicine"),
    ("emergency_medicine", "Emergency Medicine", "π",
     "Emergency care and acute medicine"),
    ("critical_care", "Critical Care Medicine", "π₯",
     "Intensive care and critical care medicine"),
    ("pathology", "Pathology", "π¬",
     "Disease diagnosis and laboratory medicine"),
    ("laboratory_medicine", "Laboratory Medicine", "π§ͺ",
     "Clinical laboratory testing and diagnostics"),
    ("medical_imaging", "Medical Imaging & Radiology AI", "π·",
     "Medical imaging and radiological diagnosis"),
    ("bioinformatics", "Bioinformatics", "π»",
     "Computational biology and data analysis"),
    ("genomics", "Genomics & Sequencing", "π§¬",
     "Genomic research and sequencing technologies"),
    ("pharmacology", "Pharmacology", "π",
     "Drug research and pharmacology"),
    ("public_health", "Public Health Analytics", "π",
     "Public health and epidemiology"),
    ("pain_medicine", "Pain Medicine", "π©Ή",
     "Pain management and treatment"),
    ("nutrition", "Nutrition", "π",
     "Nutritional science and dietetics"),
    ("allergy_immunology", "Allergy & Immunology", "π€§",
     "Allergies and immune system disorders"),
    ("rehabilitation_medicine", "Rehabilitation Medicine", "βΏ",
     "Physical medicine and rehabilitation"),
    ("auto", "Auto-detect", "π",
     "Automatic domain detection"),
)

# Public list of domain dicts; key order (id, name, icon, description) is kept
# stable because it is serialized directly into API responses.
MEDICAL_DOMAINS = [
    {"id": domain_id, "name": label, "icon": icon, "description": blurb}
    for domain_id, label, icon, blurb in _DOMAIN_ROWS
]
# User roles offered to clients, intended to match the RoleBasedReasoning roles
# in rag_engine.py. Each row is (id, display name, icon, description); the final
# "auto" row asks for the user's role to be detected automatically.
_ROLE_ROWS = (
    ("patient", "Patient", "π©Ί",
     "Patients and general public seeking health information"),
    ("student", "Student", "π",
     "Medical students and trainees"),
    ("clinician", "Clinician", "π¨ββοΈ",
     "Healthcare providers and nurses"),
    ("doctor", "Doctor", "βοΈ",
     "Medical doctors and physicians"),
    ("researcher", "Researcher", "π¬",
     "Academic researchers and scientists"),
    ("professor", "Professor", "π",
     "Academic educators and professors"),
    ("pharmacist", "Pharmacist", "π",
     "Pharmacy professionals and pharmacists"),
    ("general", "General User", "π€",
     "General audience"),
    ("auto", "Auto-detect", "π€",
     "Automatically detect user role"),
)

# Public list of role dicts; key order (id, name, icon, description) is kept
# stable because it is serialized directly into API responses.
USER_CONTEXTS = [
    {"id": role_id, "name": label, "icon": icon, "description": blurb}
    for role_id, label, icon, blurb in _ROLE_ROWS
]
# Id sets for O(1) request validation, derived from the definition lists so the
# two can never drift apart. Both include the "auto" pseudo-entry.
VALID_DOMAINS: Set[str] = set(entry["id"] for entry in MEDICAL_DOMAINS)
VALID_USER_CONTEXTS: Set[str] = set(entry["id"] for entry in USER_CONTEXTS)
# ============================================================================
# PYDANTIC MODELS (ROLE-BASED REASONING)
# ============================================================================
class SessionCreate(BaseModel):
    """Request body for opening a chat session.

    Every field is optional, so an empty body is valid.
    """

    # Caller-chosen session id; create_session generates a UUID4 when it is
    # missing or empty.
    session_id: Optional[str] = None
    # Role id from USER_CONTEXTS (this field replaced the older "user_context").
    user_role: str = "auto"
    # Optional free-text prompt stored with the session for role-based reasoning.
    custom_role_prompt: Optional[str] = None
class ChatRequest(BaseModel):
    """Request body for a single chat message.

    Only ``message`` and ``session_id`` are required; the remaining fields
    carry defaults so minimal clients may omit them.
    """

    # The user's question.
    message: str
    # Session the message belongs to.
    session_id: str
    # Domain id from MEDICAL_DOMAINS; unknown ids fall back to "general_medical".
    domain: Optional[str] = "general_medical"
    # Role id from USER_CONTEXTS; an existing session's stored role takes precedence.
    user_role: str = "auto"
    # Optional free-text prompt for role-based reasoning.
    custom_role_prompt: Optional[str] = None
    # Maximum number of papers to consult; the chat endpoint clamps it to 1..50.
    max_papers: int = 15
    # Forwarded to the engine as "use_real_time" unless None.
    use_real_time: Optional[bool] = True
    # Forwarded to the engine as "use_fallback" unless None.
    use_fallback: Optional[bool] = False
class ChatResponse(BaseModel):
    """Response body returned by the chat endpoint.

    On failure ``success`` is False, ``message`` describes the error and
    ``error`` holds the exception text; the metric fields are then unset.
    """

    # False when processing raised an exception.
    success: bool
    # Answer text, or a human-readable error message.
    message: str
    # Echo of the request's session id.
    session_id: str
    # Seconds spent waiting for the engine.
    processing_time: Optional[float] = None
    # Engine-reported overall confidence.
    confidence_score: Optional[float] = None
    # Paper counts reported by the engine (total / real / demo).
    papers_used: Optional[int] = None
    real_papers: Optional[int] = None
    demo_papers: Optional[int] = None
    # Role the answer was generated for (replaces the older "user_context").
    user_role: Optional[str] = None
    # Domain id actually used after validation.
    domain: Optional[str] = None
    # Reasoning strategy label reported by the engine.
    reasoning_method: Optional[str] = None
    # Unmodified engine payload for clients that need extra fields.
    raw_response: Optional[Dict] = None
    # Exception text when success is False.
    error: Optional[str] = None
# ============================================================================
# FASTAPI APP INITIALIZATION
# ============================================================================
app = FastAPI(
    title="Medical Research AI with Role-Based Reasoning",
    description="Medical Research Assistant with Evidence-Based Analysis and Role-Based Responses",
    version="2.2.0",
    docs_url="/api/docs",
    redoc_url="/api/redoc",
)

# CORS middleware.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets
# any origin make credentialed cross-site requests. Restrict the origin list
# before cookies or auth headers are introduced.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _resolve_asset_dir(name: str) -> Optional[str]:
    """Return the first existing directory called *name*, or None.

    The current working directory is tried first (the previous behaviour).
    The directory containing this file is tried second, so assets that live
    next to this module are still found when the process is started from a
    different directory.
    """
    candidates = (
        name,
        os.path.join(os.path.dirname(os.path.abspath(__file__)), name),
    )
    for candidate in candidates:
        # isdir (not exists): a stray *file* with this name must not be mounted.
        if os.path.isdir(candidate):
            return candidate
    return None


# Mount static files (only if the directory exists)
static_dir = "static"
static_path = _resolve_asset_dir(static_dir)
if static_path:
    app.mount("/static", StaticFiles(directory=static_path), name="static")
else:
    logger.warning("Static directory '%s' not found", static_dir)

# Templates (only if the directory exists); None selects the inline fallback page.
templates_dir = "templates"
templates_path = _resolve_asset_dir(templates_dir)
if templates_path:
    templates = Jinja2Templates(directory=templates_path)
else:
    templates = None
    logger.warning("Templates directory '%s' not found", templates_dir)

# Single engine instance shared by every request and WebSocket.
chat_engine = MedicalResearchEngine()

# Open WebSockets keyed by session id.
active_connections: Dict[str, WebSocket] = {}

# In-memory session store keyed by session id (lost on process restart).
user_sessions: Dict[str, Dict] = {}
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def _normalize_identifier(value: Optional[str]) -> str:
    """Canonicalise a client-supplied id: trim, lower-case, map '-' and ' ' to '_'.

    Non-string input (None from an explicit JSON null, numbers, lists) becomes
    the empty string, so the callers' membership test fails cleanly instead of
    raising (an unhashable value used to raise TypeError on the set lookup).
    """
    if not isinstance(value, str):
        return ""
    return value.strip().lower().replace("-", "_").replace(" ", "_")


def validate_domain(domain: Optional[str]) -> str:
    """Return a known domain id for *domain*, falling back to 'general_medical'.

    Matching tolerates case, surrounding whitespace and '-'/' ' in place of
    '_' (e.g. "Internal Medicine" -> "internal_medicine"). Unknown values are
    logged and replaced by the default.
    """
    normalized = _normalize_identifier(domain)
    if normalized in VALID_DOMAINS:
        return normalized
    logger.warning("Invalid domain '%s', defaulting to 'general_medical'", domain)
    return "general_medical"


def validate_user_role(user_role: Optional[str]) -> str:
    """Return a known role id for *user_role*, falling back to 'general'.

    Uses the same tolerant matching as ``validate_domain``. Unknown values are
    logged and replaced by the default.
    """
    normalized = _normalize_identifier(user_role)
    if normalized in VALID_USER_CONTEXTS:
        return normalized
    logger.warning("Invalid user_role '%s', defaulting to 'general'", user_role)
    return "general"
def get_domain_by_id(domain_id: str) -> Optional[Dict]:
    """Look up a domain definition by id.

    Returns the matching entry of MEDICAL_DOMAINS itself (not a copy), or
    None when no entry carries that id.
    """
    return next(
        (entry for entry in MEDICAL_DOMAINS if entry["id"] == domain_id),
        None,
    )
def get_user_role_by_id(role_id: str) -> Optional[Dict]:
    """Look up a user-role definition by id.

    Returns the matching entry of USER_CONTEXTS itself (not a copy), or None
    when no entry carries that id.
    """
    return next(
        (entry for entry in USER_CONTEXTS if entry["id"] == role_id),
        None,
    )
def split_into_chunks(text: str, chunk_size: int = 200) -> List[str]:
    """Split *text* into consecutive slices of at most *chunk_size* characters.

    Used to fake streaming when sending answers over the WebSocket. The last
    chunk may be shorter than *chunk_size*; an empty string yields ``[]``.

    Raises:
        ValueError: if *chunk_size* is not positive. Without this check a
            negative size silently returned ``[]`` (dropping the whole text)
            and zero produced an opaque ``range()`` error.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    return [text[start:start + chunk_size] for start in range(0, len(text), chunk_size)]
# ============================================================================
# ROUTES (ROLE-BASED REASONING)
# ============================================================================
# Landing page served when no templates directory is available. The
# __DOMAIN_COUNT__ / __ROLE_COUNT__ markers are filled in by home() so the
# advertised numbers follow MEDICAL_DOMAINS / USER_CONTEXTS. Plain markers are
# used instead of str.format because the embedded CSS is full of braces.
_FALLBACK_HOME_HTML = """
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Medical Research AI with Role-Based Reasoning</title>
<style>
body { font-family: Arial, sans-serif; margin: 0; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); min-height: 100vh; }
.container { max-width: 800px; margin: 50px auto; background: white; padding: 40px; border-radius: 15px; box-shadow: 0 20px 60px rgba(0,0,0,0.3); }
h1 { color: #333; margin-bottom: 10px; }
.tagline { color: #666; font-size: 18px; margin-bottom: 30px; }
.stats { display: flex; justify-content: space-between; margin: 30px 0; }
.stat { text-align: center; flex: 1; padding: 20px; }
.stat-number { font-size: 36px; font-weight: bold; color: #667eea; }
.stat-label { color: #666; margin-top: 5px; }
.api-link { display: block; margin: 15px 0; padding: 15px; background: #f8f9fa; border-radius: 8px; text-decoration: none; color: #333; border-left: 4px solid #667eea; transition: all 0.3s; }
.api-link:hover { background: #e9ecef; transform: translateX(5px); }
.btn { display: inline-block; padding: 12px 24px; background: #667eea; color: white; text-decoration: none; border-radius: 6px; margin: 10px 5px; }
.feature-grid { display: grid; grid-template-columns: repeat(2, 1fr); gap: 20px; margin: 30px 0; }
.feature-card { padding: 20px; background: #f8f9fa; border-radius: 10px; border-left: 4px solid #667eea; }
</style>
</head>
<body>
<div class="container">
<h1>π₯ Medical Research AI with Role-Based Reasoning</h1>
<p class="tagline">Evidence-based medical research assistant with AI-powered insights tailored to your role</p>
<div class="stats">
<div class="stat">
<div class="stat-number">__DOMAIN_COUNT__</div>
<div class="stat-label">Medical Domains</div>
</div>
<div class="stat">
<div class="stat-number">__ROLE_COUNT__</div>
<div class="stat-label">User Roles</div>
</div>
<div class="stat">
<div class="stat-number">Role-Based</div>
<div class="stat-label">Responses</div>
</div>
</div>
<h2>π Key Features</h2>
<div class="feature-grid">
<div class="feature-card">
<strong>π€ Role-Based Responses</strong>
<p>Tailored answers for patients, doctors, researchers, and more</p>
</div>
<div class="feature-card">
<strong>π₯ Domain-Specific</strong>
<p>__DOMAIN_COUNT__ medical specialties with specialized knowledge</p>
</div>
<div class="feature-card">
<strong>π¬ Evidence-Based</strong>
<p>Research-backed answers with confidence scoring</p>
</div>
<div class="feature-card">
<strong>π Guideline Detection</strong>
<p>Automatic detection of clinical guidelines</p>
</div>
</div>
<h2>π API Documentation</h2>
<a href="/api/docs" class="api-link">π OpenAPI/Swagger Documentation</a>
<a href="/api/redoc" class="api-link">π ReDoc Documentation</a>
<h2>π§ API Endpoints</h2>
<a href="/api/health" class="api-link">β€οΈ Health Check</a>
<a href="/api/v1/domains" class="api-link">π₯ Available Medical Domains</a>
<a href="/api/v1/roles" class="api-link">π€ User Roles</a>
<a href="/api/v1/engine/status" class="api-link">βοΈ Engine Status</a>
<h2>π Quick Start</h2>
<div style="margin: 20px 0;">
<a href="/api/docs" class="btn">View API Docs</a>
<a href="https://github.com/yourusername/medical-research-ai" class="btn" style="background: #333;">GitHub</a>
</div>
<div style="margin-top: 30px; padding-top: 20px; border-top: 1px solid #eee; color: #666; font-size: 14px;">
<p>π Deployed on Vercel | β‘ FastAPI | 𧬠Medical AI | π€ Role-Based Reasoning</p>
</div>
</div>
</body>
</html>
"""


async def home(request: Request):
    """Serve the chat UI, or a static landing page when no templates exist.

    The landing page's domain and role counts are computed from
    MEDICAL_DOMAINS / USER_CONTEXTS, excluding the "auto" pseudo-entries
    (previously hard-coded as 36 domains, while the list holds 33 real ones).
    """
    if templates:
        return templates.TemplateResponse("index.html", {"request": request})

    domain_count = sum(1 for entry in MEDICAL_DOMAINS if entry["id"] != "auto")
    role_count = sum(1 for entry in USER_CONTEXTS if entry["id"] != "auto")
    html_content = (
        _FALLBACK_HOME_HTML
        .replace("__DOMAIN_COUNT__", str(domain_count))
        .replace("__ROLE_COUNT__", str(role_count))
    )
    return HTMLResponse(content=html_content)
async def health_check():
    """Report liveness plus static configuration and in-memory usage counters.

    Deliberately avoids calling into the engine: the engine status was fetched
    here and then never used, and that call could only make the health check
    slower or fail.
    """
    return {
        "status": "healthy",
        "engine": "Medical Research Engine with Role-Based Reasoning",
        "version": "2.2.0",
        "timestamp": datetime.now().isoformat(),
        # Engines lacking the attribute (e.g. the fallback) report False.
        "engine_configured": getattr(chat_engine, "api_configured", False),
        "features": [
            "Role-Based Medical Analysis",
            "Domain-Specific Research",
            "User Role Adaptation",
            "Paper Summarization",
            "Guideline Detection",
            "Simple Query Handling",
        ],
        "stats": {
            "domains_count": len(MEDICAL_DOMAINS),
            "user_roles_count": len(USER_CONTEXTS),
            "active_sessions": len(user_sessions),
            "active_connections": len(active_connections),
        },
    }
async def get_domains():
    """List every selectable medical domain, including the "auto" entry."""
    return dict(
        success=True,
        domains=MEDICAL_DOMAINS,
        count=len(MEDICAL_DOMAINS),
        timestamp=datetime.now().isoformat(),
    )
async def get_domain_info(domain_id: str):
    """Return one domain's definition, or a 404 JSON error for an unknown id."""
    found = get_domain_by_id(domain_id)
    if found is not None:
        return dict(
            success=True,
            domain=found,
            timestamp=datetime.now().isoformat(),
        )
    return JSONResponse(
        status_code=404,
        content={"error": f"Domain '{domain_id}' not found"},
    )
async def get_user_roles():
    """List every selectable user role, including the "auto" entry."""
    return dict(
        success=True,
        user_roles=USER_CONTEXTS,
        count=len(USER_CONTEXTS),
        timestamp=datetime.now().isoformat(),
    )
async def get_user_role_info(role_id: str):
    """Return one user role's definition, or a 404 JSON error for an unknown id."""
    found = get_user_role_by_id(role_id)
    if found is not None:
        return dict(
            success=True,
            user_role=found,
            timestamp=datetime.now().isoformat(),
        )
    return JSONResponse(
        status_code=404,
        content={"error": f"User role '{role_id}' not found"},
    )
async def create_session(request: Optional[SessionCreate] = None):
    """Create a chat session and return it with a role-tailored welcome message.

    A missing body is treated as an empty ``SessionCreate``. A missing or empty
    ``session_id`` gets a fresh UUID4. Passing an existing id replaces that
    session's stored state (counters and tracked domains start over).
    """
    if request is None:
        request = SessionCreate()

    session_id = request.session_id or str(uuid.uuid4())
    user_role = validate_user_role(request.user_role)
    # One timestamp, so created_at and last_active agree for a new session.
    now = datetime.now().isoformat()

    user_sessions[session_id] = {
        "id": session_id,
        "created_at": now,
        "user_role": user_role,
        "custom_role_prompt": request.custom_role_prompt,
        "message_count": 0,
        "domains_used": set(),
        "last_active": now,
    }

    # Initialize engine for this session (only engines that support it).
    if hasattr(chat_engine, "initialize_session"):
        chat_engine.initialize_session(session_id)

    role_info = get_user_role_by_id(user_role)
    role_name = role_info["name"] if role_info else user_role
    role_icon = role_info["icon"] if role_info else "π€"
    role_noun = role_info["name"].lower() if role_info else user_role
    guideline_line = (
        "Clinical guideline references"
        if user_role in ("clinician", "doctor")
        else "Appropriate level of detail"
    )
    # "Specialties" exclude the "auto" pseudo-domain and the catch-all
    # "general_medical" (assumed intent of the previous hard-coded "- 2");
    # counting explicitly keeps the figure right when the list changes.
    specialty_count = sum(
        1 for entry in MEDICAL_DOMAINS
        if entry["id"] not in ("auto", "general_medical")
    )

    return {
        "session_id": session_id,
        "user_role": user_role,
        "custom_role_prompt": request.custom_role_prompt,
        "role_info": role_info,
        "created_at": now,
        "welcome_message": f"""π **Welcome to Medical Research Assistant!** 𧬠
π€ **Your role:** {role_name} {role_icon}
π₯ **Available Specialties:** {specialty_count} medical domains
I can help you with:
β’ **Medical Research Analysis** - Evidence-based insights
β’ **Paper Summarization** - Key findings and implications
β’ **Clinical Context** - Domain-specific applications
β’ **Research Gap Identification** - Future directions
**Role-Specific Features:**
- Tailored responses based on your role
- Appropriate terminology for your expertise level
- Relevant practical implications
- {guideline_line}
**Try asking:**
β’ "Latest treatments for diabetes" (Endocrinology)
β’ "Research on Alzheimer's biomarkers" (Neurology)
β’ "Clinical guidelines for hypertension" (Cardiology)
β’ "Summarize recent advances in cancer immunotherapy" (Oncology)
I'll adapt my responses to your specific needs as a {role_noun}.""",
    }
async def chat_endpoint(request: ChatRequest):
    """Answer one chat message through the engine using role-based reasoning.

    Inputs are normalised first: unknown domains/roles fall back to defaults
    and ``max_papers`` is clamped to 1..50. For a known session, the role
    stored on the session wins over ``request.user_role``. A custom role
    prompt stored on the session is used when the request carries none.

    Returns:
        ChatResponse. Any exception is logged and reported as
        ``success=False`` with the error text instead of propagating.
    """
    import time  # monotonic clock for processing_time

    try:
        domain = validate_domain(request.domain)
        user_role = validate_user_role(request.user_role)
        # Clamp rather than reject; the request object itself is left untouched.
        max_papers = min(max(request.max_papers, 1), 50)
        custom_role_prompt = request.custom_role_prompt

        session = user_sessions.get(request.session_id)
        if session is not None:
            session["last_active"] = datetime.now().isoformat()
            session["message_count"] = session.get("message_count", 0) + 1
            # Sessions opened over the WebSocket have no "domains_used" entry;
            # indexing it directly raised KeyError for them.
            session.setdefault("domains_used", set()).add(domain)
            # domains_used is an unordered set, so record the latest domain too.
            session["last_domain"] = domain
            # The session's role takes precedence over the request's role.
            if session.get("user_role"):
                user_role = session["user_role"]
            else:
                session["user_role"] = user_role
            if custom_role_prompt:
                session["custom_role_prompt"] = custom_role_prompt
            else:
                # Otherwise the prompt saved at session creation was never used.
                custom_role_prompt = session.get("custom_role_prompt")

        logger.info("Processing chat - Domain: %s, Role: %s", domain, user_role)

        engine_kwargs = {
            "query": request.message,
            "domain": domain,
            "session_id": request.session_id,
            "user_role": user_role,
            "max_papers": max_papers,
        }
        # Optional parameters are forwarded only when set.
        if custom_role_prompt:
            engine_kwargs["custom_role_prompt"] = custom_role_prompt
        if request.use_real_time is not None:
            engine_kwargs["use_real_time"] = request.use_real_time
        if request.use_fallback is not None:
            engine_kwargs["use_fallback"] = request.use_fallback

        started = time.perf_counter()
        response = await chat_engine.process_query_async(**engine_kwargs)
        processing_time = time.perf_counter() - started

        # Re-check membership: the session may have been deleted while awaiting.
        if request.session_id in user_sessions:
            user_sessions[request.session_id].setdefault("query_types", []).append(
                response.get("query_type", "unknown")
            )

        # confidence_score is read as a dict holding "overall_score"; a bare
        # number or None must not turn a successful answer into an error.
        raw_confidence = response.get("confidence_score")
        if isinstance(raw_confidence, dict):
            confidence = raw_confidence.get("overall_score", 0)
        elif isinstance(raw_confidence, (int, float)):
            confidence = raw_confidence
        else:
            confidence = 0

        return ChatResponse(
            success=True,
            message=response.get("answer", "No response generated"),
            session_id=request.session_id,
            processing_time=processing_time,
            confidence_score=confidence,
            papers_used=response.get("papers_used", 0),
            real_papers=response.get("real_papers_used", 0),
            demo_papers=response.get("demo_papers_used", 0),
            user_role=response.get("user_role", user_role),
            domain=domain,
            reasoning_method=response.get("reasoning_method", "role_based"),
            raw_response=response,
        )
    except Exception as e:
        logger.error("Chat endpoint error: %s", e, exc_info=True)
        return ChatResponse(
            success=False,
            message=f"β Error: {str(e)}",
            session_id=request.session_id,
            error=str(e),
            user_role=request.user_role,
        )
# Strong references to in-flight message tasks. asyncio keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected before
# it finishes; each task removes itself from the set when done.
_ws_background_tasks: Set["asyncio.Task"] = set()


async def websocket_chat(websocket: WebSocket):
    """Serve one WebSocket client for real-time, role-based chat.

    Clients send JSON objects with a "type" field: ``init_session`` first,
    then ``message``, ``update_role``, ``update_domain`` or ``clear_history``.
    ``get_domains`` and ``get_roles`` work at any time. Chat messages are
    answered by a background task so the receive loop stays responsive.
    Malformed frames get an error reply instead of closing the connection.
    """
    await websocket.accept()
    session_id = None
    user_role = "general"

    try:
        while True:
            try:
                data = await websocket.receive_json()
            except ValueError:
                # json.JSONDecodeError is a ValueError: report it and keep going.
                await websocket.send_json({
                    "type": "error",
                    "message": "Invalid JSON message",
                })
                continue
            if not isinstance(data, dict):
                await websocket.send_json({
                    "type": "error",
                    "message": "Message must be a JSON object",
                })
                continue

            message_type = data.get("type")

            if message_type == "init_session":
                # Create or reuse a session.
                session_id = data.get("session_id") or str(uuid.uuid4())
                user_role = validate_user_role(data.get("user_role", "general"))
                if session_id not in user_sessions:
                    now = datetime.now().isoformat()
                    user_sessions[session_id] = {
                        "id": session_id,
                        "created_at": now,
                        "user_role": user_role,
                        "custom_role_prompt": data.get("custom_role_prompt"),
                        "message_count": 0,
                        # Same bookkeeping keys as REST-created sessions, so the
                        # REST chat endpoint (which adds to domains_used) also
                        # works for sessions opened here.
                        "domains_used": set(),
                        "last_active": now,
                        "websocket": websocket,
                    }
                    # NOTE(review): original nesting was lost; assumed the engine
                    # is initialised only for newly created sessions.
                    if hasattr(chat_engine, "initialize_session"):
                        chat_engine.initialize_session(session_id)
                active_connections[session_id] = websocket
                role_info = get_user_role_by_id(user_role)
                await websocket.send_json({
                    "type": "session_created",
                    "session_id": session_id,
                    "user_role": user_role,
                    "role_info": role_info,
                    "custom_role_prompt": data.get("custom_role_prompt"),
                    "timestamp": datetime.now().isoformat(),
                    "features": [
                        "role_based_medical_research",
                        "domain_specific_insights",
                        "guideline_detection",
                        "simple_query_handling",
                    ],
                    "stats": {
                        "domains_available": len(MEDICAL_DOMAINS),
                        "user_roles_available": len(USER_CONTEXTS),
                    },
                })

            elif message_type == "message" and session_id:
                user_message = data.get("message", "")
                domain = validate_domain(data.get("domain", "general_medical"))
                user_role = validate_user_role(data.get("user_role", user_role))
                custom_role_prompt = data.get("custom_role_prompt")

                session = user_sessions.get(session_id)
                if session is not None:
                    session["user_role"] = user_role
                    if custom_role_prompt:
                        session["custom_role_prompt"] = custom_role_prompt
                    else:
                        # Use the prompt stored on the session when none is sent.
                        custom_role_prompt = session.get("custom_role_prompt")
                    # Keep the same activity bookkeeping as the REST endpoint.
                    session["message_count"] = session.get("message_count", 0) + 1
                    session["last_active"] = datetime.now().isoformat()
                    session.setdefault("domains_used", set()).add(domain)
                    session["last_domain"] = domain

                await websocket.send_json({
                    "type": "typing",
                    "is_typing": True,
                })

                task = asyncio.create_task(
                    process_websocket_message(
                        websocket, session_id, user_message,
                        domain, user_role, custom_role_prompt, data,
                    )
                )
                _ws_background_tasks.add(task)
                task.add_done_callback(_ws_background_tasks.discard)

            elif message_type == "update_role" and session_id:
                new_role = validate_user_role(data.get("user_role", "general"))
                user_role = new_role
                if session_id in user_sessions:
                    user_sessions[session_id]["user_role"] = new_role
                role_info = get_user_role_by_id(new_role)
                await websocket.send_json({
                    "type": "role_updated",
                    "user_role": user_role,
                    "role_info": role_info,
                    "session_id": session_id,
                })

            elif message_type == "update_domain" and session_id:
                new_domain = validate_domain(data.get("domain", "general_medical"))
                domain_info = get_domain_by_id(new_domain)
                await websocket.send_json({
                    "type": "domain_updated",
                    "domain": new_domain,
                    "domain_info": domain_info,
                    "session_id": session_id,
                })

            elif message_type == "clear_history" and session_id:
                # NOTE(review): clear_memory() takes no session id, so it may
                # clear memory for every session - confirm against the engine.
                if hasattr(chat_engine, "clear_memory"):
                    chat_engine.clear_memory()
                await websocket.send_json({
                    "type": "history_cleared",
                    "session_id": session_id,
                })

            elif message_type == "get_domains":
                await websocket.send_json({
                    "type": "domains_list",
                    "domains": MEDICAL_DOMAINS,
                    "count": len(MEDICAL_DOMAINS),
                })

            elif message_type == "get_roles":
                await websocket.send_json({
                    "type": "roles_list",
                    "user_roles": USER_CONTEXTS,
                    "count": len(USER_CONTEXTS),
                })

            elif message_type in ("message", "update_role", "update_domain", "clear_history"):
                # Only reachable before init_session (session_id is still None);
                # these used to be dropped silently.
                await websocket.send_json({
                    "type": "error",
                    "message": "Session not initialized; send 'init_session' first",
                })

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected: %s", session_id)
    except Exception as e:
        logger.error("WebSocket error: %s", e, exc_info=True)
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"Connection error: {str(e)}",
            })
        except Exception:
            # The socket is most likely already closed; nothing more to report.
            logger.debug("Could not deliver WebSocket error for session %s", session_id)
    finally:
        # Unregister only if the entry still points at this socket; a reconnect
        # may already have registered a newer socket for the same session.
        if session_id is not None and active_connections.get(session_id) is websocket:
            del active_connections[session_id]
async def process_websocket_message(websocket: WebSocket, session_id: str,
                                    user_message: str, domain: str,
                                    user_role: str, custom_role_prompt: Optional[str],
                                    data: dict):
    """Run one chat query for a WebSocket client and stream the answer back.

    Sends, in order: a ``context_info`` frame, the answer split into 200-char
    ``message_chunk`` frames (with a short delay between them to simulate
    streaming), and a ``message_complete`` frame with metadata. Failures are
    reported as an ``error`` frame. This coroutine runs as a background task,
    so it never lets an exception escape; otherwise asyncio would log
    "Task exception was never retrieved".
    """
    try:
        # Same 1..50 clamp as the REST endpoint; bad values fall back to 15.
        try:
            max_papers = int(data.get("max_papers", 15))
        except (TypeError, ValueError):
            max_papers = 15
        max_papers = min(max(max_papers, 1), 50)

        engine_kwargs = {
            "query": user_message,
            "domain": domain,
            "session_id": session_id,
            "user_role": user_role,
            "max_papers": max_papers,
        }
        if custom_role_prompt:
            engine_kwargs["custom_role_prompt"] = custom_role_prompt
        if data.get("use_real_time") is not None:
            engine_kwargs["use_real_time"] = data.get("use_real_time")
        if data.get("use_fallback") is not None:
            engine_kwargs["use_fallback"] = data.get("use_fallback")

        response = await chat_engine.process_query_async(**engine_kwargs)

        # Send domain and role info
        domain_info = get_domain_by_id(domain)
        role_info = get_user_role_by_id(user_role)
        await websocket.send_json({
            "type": "context_info",
            "user_role": response.get("user_role", user_role),
            "domain": domain,
            "domain_info": domain_info,
            "role_info": role_info,
            "reasoning_method": response.get("reasoning_method", "role_based"),
        })

        # Send the answer in chunks (for a streaming effect)
        answer = response.get("answer", "")
        chunks = split_into_chunks(answer, 200)
        for index, chunk in enumerate(chunks):
            await websocket.send_json({
                "type": "message_chunk",
                "chunk": chunk,
                "is_final": index == len(chunks) - 1,
                "chunk_index": index,
                "total_chunks": len(chunks),
            })
            await asyncio.sleep(0.05)

        # confidence_score is read as a dict holding "overall_score"; tolerate a
        # bare number or None instead of failing the whole reply.
        raw_confidence = response.get("confidence_score")
        if isinstance(raw_confidence, dict):
            confidence = raw_confidence.get("overall_score", 0)
        elif isinstance(raw_confidence, (int, float)):
            confidence = raw_confidence
        else:
            confidence = 0

        # Send the complete message with metadata
        await websocket.send_json({
            "type": "message_complete",
            "message": answer,
            "metadata": {
                "confidence_score": confidence,
                "papers_used": response.get("papers_used", 0),
                "real_papers": response.get("real_papers_used", 0),
                "demo_papers": response.get("demo_papers_used", 0),
                "user_role": response.get("user_role", user_role),
                "domain": domain,
                "reasoning_method": response.get("reasoning_method", "role_based"),
                "query_type": response.get("query_type", "general"),
            },
        })
    except Exception as e:
        logger.error("WebSocket message processing error: %s", e, exc_info=True)
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"Processing error: {str(e)}",
            })
        except Exception:
            # Client already gone; swallow so the background task ends cleanly.
            logger.debug("Could not deliver processing error for session %s", session_id)
async def get_session_info(session_id: str):
    """Return a session's role, activity counters and domain usage.

    Responds with a 404 JSON error when the session is unknown.
    """
    session = user_sessions.get(session_id)
    if session is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found"},
        )

    # "or []" also covers a stored None, which list() would reject.
    domains_used = list(session.get("domains_used") or [])
    # domains_used is stored as a set, whose element order is arbitrary, so the
    # "last" element is not the latest domain. Prefer an explicitly recorded
    # "last_domain" and fall back to an arbitrary used domain only without one.
    last_domain = session.get("last_domain") or (domains_used[-1] if domains_used else None)
    domain_info = get_domain_by_id(last_domain) if last_domain else None

    user_role = session.get("user_role", "general")
    role_info = get_user_role_by_id(user_role)

    return {
        "session_id": session_id,
        "created_at": session.get("created_at"),
        "user_role": user_role,
        "custom_role_prompt": session.get("custom_role_prompt"),
        "role_info": role_info,
        "message_count": session.get("message_count", 0),
        "last_active": session.get("last_active"),
        "domains_used": domains_used,
        "last_domain_info": domain_info,
        "query_types": session.get("query_types", []),
        "features_enabled": [
            "role_based_medical_research",
            "domain_specific_insights",
            "user_role_adaptation",
            "guideline_detection",
            "simple_query_handling",
        ],
    }
async def update_session_role(session_id: str, request: dict):
    """Change a session's role and, optionally, its custom role prompt.

    Body keys:
        user_role: new role id (validated). When omitted, the current role is
            kept; previously it silently reset to "general", so a prompt-only
            update clobbered the role.
        custom_role_prompt: applied only when non-empty.

    Responds with a 404 JSON error when the session is unknown.
    """
    session = user_sessions.get(session_id)
    if session is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found"},
        )

    new_role = validate_user_role(
        request.get("user_role", session.get("user_role", "general"))
    )
    session["user_role"] = new_role
    custom_role_prompt = request.get("custom_role_prompt")
    if custom_role_prompt:
        session["custom_role_prompt"] = custom_role_prompt

    role_info = get_user_role_by_id(new_role)
    return {
        "success": True,
        "session_id": session_id,
        "user_role": new_role,
        "custom_role_prompt": session.get("custom_role_prompt"),
        "role_info": role_info,
        "message": f"User role updated to {new_role}",
    }
async def delete_session(session_id: str):
    """Forget a session and close its WebSocket, if one is registered.

    Idempotent: it always reports success, including for unknown ids.
    """
    if session_id in user_sessions:
        # NOTE(review): clear_memory() takes no session id, so this may clear
        # the engine's memory for every session - confirm against the engine.
        if hasattr(chat_engine, "clear_memory"):
            chat_engine.clear_memory()
        del user_sessions[session_id]

    # Detach the socket *before* awaiting close(): while close() is awaited, the
    # socket's own disconnect handler can remove the registry entry, and the
    # former `del active_connections[...]` afterwards could then raise KeyError.
    websocket = active_connections.pop(session_id, None)
    if websocket is not None:
        try:
            await websocket.close()
        except Exception:
            # Best effort: the client may already be gone. Log instead of the
            # previous bare `except: pass`, which also swallowed
            # KeyboardInterrupt/SystemExit.
            logger.warning("Failed to close WebSocket for session %s", session_id, exc_info=True)

    return {"success": True, "message": "Session deleted"}
async def get_engine_status():
    """Report engine metrics merged over fixed API metadata.

    Keys returned by the engine's ``get_engine_status()`` are spread last, so
    they override the fixed fields on a name clash. Engines without that
    method yield a ``success: False`` placeholder.
    """
    if not hasattr(chat_engine, "get_engine_status"):
        return {
            "success": False,
            "engine": "Unknown",
            "message": "Engine status not available",
        }

    engine_metrics = chat_engine.get_engine_status()
    return {
        "success": True,
        "engine": "Medical Research Engine with Role-Based Reasoning",
        "version": "2.2.0",
        "domains_supported": len(MEDICAL_DOMAINS),
        "user_roles_supported": len(USER_CONTEXTS),
        **engine_metrics,
    }
# ============================================================================
# DEVELOPMENT ONLY - Local server run
# ============================================================================
if __name__ == "__main__" and os.getenv("VERCEL") is None:
    # Only run locally, not on Vercel.
    import uvicorn

    print(f"\n{'=' * 60}")
    print("π STARTING MEDICAL RESEARCH AI SERVER (LOCAL)")
    print(" Version: 2.2.0 - Role-Based Reasoning")
    print(f"{'=' * 60}")
    print("π API Docs: http://localhost:8000/api/docs")
    print(f"π₯ Medical Domains: {len(MEDICAL_DOMAINS)}")
    print(f"π€ User Roles: {len(USER_CONTEXTS)}")
    print("π§ Features: Role-based reasoning, Simple query handling")
    print(f"{'=' * 60}\n")

    # reload=True is not used here: uvicorn only supports reload when given an
    # import string ("module:app"), and with an app object it logs a warning
    # and exits immediately. For auto-reload run e.g.:
    #     uvicorn api.main:app --reload
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )