"""
Medical Research API Server - HuggingFace Compatible Version
"""
# Add this for HuggingFace compatibility
import asyncio
import logging
import os
import sys
import uuid
from datetime import datetime
from typing import Dict, List, Optional, Set
# Deployment-specific settings. HuggingFace Spaces is detected through the
# SPACE_ID environment variable.
_space_id = os.getenv('SPACE_ID')
if _space_id:
    # We're on HuggingFace Spaces
    print(f"🚀 Running on HuggingFace Space: {_space_id}")
    # Force port 7860 for HuggingFace
    os.environ['PORT'] = '7860'
    # Only these origins may call the API cross-origin on HuggingFace.
    ALLOWED_ORIGINS = [
        "https://medical-research-ai.vercel.app",
        "http://localhost:3000",
        "https://paulhemb-medsearchpro.hf.space",
    ]
else:
    # Everywhere else (local development, Vercel) any origin is accepted.
    ALLOWED_ORIGINS = ["*"]
# Module logger, plus root logging configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Import the engine. Vercel runs from the api/ directory, so the plain module
# name is tried first, then the package path used in local development.
try:
    from engine import MedicalResearchEngine
except ImportError as _local_import_error:
    try:
        from api.engine import MedicalResearchEngine
    except ImportError as _package_import_error:
        # Include both failure reasons: an ImportError raised *inside*
        # engine.py (e.g. one of its own dependencies is missing) would
        # otherwise look exactly like the module being absent.
        print(
            "⚠️ MedicalResearchEngine not found - using fallback mode "
            f"({_local_import_error}; {_package_import_error})"
        )

        class MedicalResearchEngine:
            """Stand-in engine used when the real engine cannot be imported."""

            async def process_query_async(self, **kwargs):
                """Ignore all arguments and return a fixed placeholder result."""
                return {"answer": "Engine not available", "papers_used": 0}
# ============================================================================
# DOMAIN AND USER CONTEXT DEFINITIONS
# ============================================================================
# Catalogue of selectable medical domains returned to clients. Two entries are
# not clinical specialties: "general_medical" (the fallback chosen by
# validate_domain for unknown ids) and "auto" (asks for domain detection).
# NOTE(review): icons for clinical_research, pharmacology, public_health,
# surgery, dermatology, emergency_medicine, pain_medicine and general_medical
# are best-guess reconstructions of mis-encoded text — confirm the intended
# glyphs.
MEDICAL_DOMAINS = [
    {"id": "internal_medicine", "name": "Internal Medicine", "icon": "🏥",
     "description": "General internal medicine and diagnosis"},
    {"id": "endocrinology", "name": "Endocrinology", "icon": "🧬",
     "description": "Hormonal and metabolic disorders"},
    {"id": "gastroenterology", "name": "Gastroenterology", "icon": "🩸",
     "description": "Digestive system disorders"},
    {"id": "pulmonology", "name": "Pulmonology", "icon": "🫁",
     "description": "Respiratory diseases and lung disorders"},
    {"id": "nephrology", "name": "Nephrology", "icon": "🧪",
     "description": "Kidney diseases and renal function"},
    {"id": "hematology", "name": "Hematology", "icon": "🩸",
     "description": "Blood disorders and hematologic diseases"},
    {"id": "infectious_disease", "name": "Infectious Diseases", "icon": "🦠",
     "description": "Infectious diseases and microbiology"},
    {"id": "obstetrics_gynecology", "name": "Obstetrics & Gynecology", "icon": "🤰",
     "description": "Women's health, pregnancy and reproductive medicine"},
    {"id": "pathology", "name": "Pathology", "icon": "🔬",
     "description": "Disease diagnosis through tissue examination"},
    {"id": "laboratory_medicine", "name": "Laboratory Medicine", "icon": "🧪",
     "description": "Clinical laboratory testing and biomarkers"},
    {"id": "bioinformatics", "name": "Bioinformatics", "icon": "💻",
     "description": "Computational analysis of biological data"},
    {"id": "clinical_research", "name": "Clinical Research", "icon": "📊",
     "description": "Clinical trials and evidence-based medicine"},
    {"id": "medical_imaging", "name": "Medical Imaging", "icon": "🩻",
     "description": "Medical imaging and radiology"},
    {"id": "oncology", "name": "Oncology", "icon": "🦠",
     "description": "Cancer research and treatment"},
    {"id": "cardiology", "name": "Cardiology", "icon": "❤️",
     "description": "Heart and cardiovascular diseases"},
    {"id": "neurology", "name": "Neurology", "icon": "🧠",
     "description": "Brain and nervous system disorders"},
    {"id": "pharmacology", "name": "Pharmacology", "icon": "💊",
     "description": "Drug therapy and medication management"},
    {"id": "genomics", "name": "Genomics", "icon": "🧬",
     "description": "Genetic research and personalized medicine"},
    {"id": "public_health", "name": "Public Health", "icon": "🌍",
     "description": "Population health and epidemiology"},
    {"id": "surgery", "name": "Surgery", "icon": "✂️",
     "description": "Surgical procedures and techniques"},
    {"id": "pediatrics", "name": "Pediatrics", "icon": "👶",
     "description": "Child health and pediatric medicine"},
    {"id": "psychiatry", "name": "Psychiatry", "icon": "🧠",
     "description": "Mental health and psychiatric disorders"},
    {"id": "dermatology", "name": "Dermatology", "icon": "🧴",
     "description": "Skin diseases and dermatologic conditions"},
    {"id": "orthopedics", "name": "Orthopedics", "icon": "🦴",
     "description": "Musculoskeletal disorders and bone health"},
    {"id": "ophthalmology", "name": "Ophthalmology", "icon": "👁️",
     "description": "Eye diseases and vision care"},
    {"id": "urology", "name": "Urology", "icon": "💧",
     "description": "Urinary system and male reproductive health"},
    {"id": "emergency_medicine", "name": "Emergency Medicine", "icon": "🚑",
     "description": "Acute care and emergency response"},
    {"id": "critical_care", "name": "Critical Care", "icon": "🏥",
     "description": "Intensive care and critical illness"},
    {"id": "pain_medicine", "name": "Pain Medicine", "icon": "⚕️",
     "description": "Pain management and analgesia"},
    {"id": "nutrition", "name": "Nutrition", "icon": "🥗",
     "description": "Clinical nutrition and dietary management"},
    {"id": "allergy_immunology", "name": "Allergy & Immunology", "icon": "🤧",
     "description": "Allergic diseases and immune disorders"},
    {"id": "rehabilitation_medicine", "name": "Rehabilitation Medicine", "icon": "♿",
     "description": "Physical therapy and recovery"},
    {"id": "general_medical", "name": "General Medical", "icon": "⚕️",
     "description": "General medical research and clinical questions"},
    {"id": "auto", "name": "Auto-detect", "icon": "🤖",
     "description": "Automatically detect domain from query"},
]
# Audience profiles a session can adopt; "auto" asks for context detection.
# NOTE(review): icons for student, patient and general are best-guess
# reconstructions of mis-encoded text — confirm the intended glyphs.
USER_CONTEXTS = [
    {"id": "auto", "name": "Auto-detect", "icon": "🤖",
     "description": "Automatically detect user context"},
    {"id": "clinician", "name": "Clinician", "icon": "👨‍⚕️",
     "description": "Medical doctors, nurses, and healthcare providers"},
    {"id": "researcher", "name": "Researcher", "icon": "🔬",
     "description": "Academic researchers and scientists"},
    {"id": "student", "name": "Student", "icon": "🎓",
     "description": "Medical students and trainees"},
    {"id": "administrator", "name": "Administrator", "icon": "💼",
     "description": "Healthcare administrators and managers"},
    {"id": "patient", "name": "Patient", "icon": "👤",
     "description": "Patients and general public"},
    {"id": "general", "name": "General", "icon": "🤝",
     "description": "General audience"},
]
# Id lookup sets used by validate_domain / validate_user_context.
VALID_DOMAINS: Set[str] = set(entry["id"] for entry in MEDICAL_DOMAINS)
VALID_USER_CONTEXTS: Set[str] = set(entry["id"] for entry in USER_CONTEXTS)
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class SessionCreate(BaseModel):
    """Request body for POST /api/v1/session/create; every field is optional."""
    # Client-chosen id; when empty, create_session generates a UUID4 instead.
    session_id: Optional[str] = None
    # Requested audience profile; unknown ids are replaced with "auto".
    user_context: str = "auto"
class ChatRequest(BaseModel):
    """Request body for POST /api/v1/chat."""
    # The user's question; forwarded to the engine as ``query``.
    message: str
    # Session the message belongs to; the endpoint also accepts unknown ids.
    session_id: str
    # Domain id; unknown values fall back to "general_medical".
    domain: Optional[str] = "general_medical"
    # Audience profile; ignored when the stored session already has one.
    user_context: str = "auto"
    # Forwarded to the engine as ``max_papers`` after clamping to 1..50.
    max_papers: int = 15
class ChatResponse(BaseModel):
    """Response body of POST /api/v1/chat, used for both success and failure."""
    # False when processing raised; ``error`` then carries the exception text.
    success: bool
    # The engine's answer on success, or a user-facing error message.
    message: str
    session_id: str
    # Seconds spent awaiting the engine (success only).
    processing_time: Optional[float] = None
    # The engine's overall confidence score (success only).
    confidence_score: Optional[float] = None
    papers_used: Optional[int] = None
    user_context: Optional[str] = None
    # The engine's response dict, passed through unchanged (success only).
    raw_response: Optional[Dict] = None
    error: Optional[str] = None
# ============================================================================
# FASTAPI APP INITIALIZATION
# ============================================================================
# Application object; interactive docs are served under /api.
app = FastAPI(
    title="Medical Research AI",
    version="1.0.0",
    description="Medical Research Assistant with Evidence-Based Analysis",
    redoc_url="/api/redoc",
    docs_url="/api/docs",
)
# CORS middleware. ALLOWED_ORIGINS is chosen at startup: an explicit list on
# HuggingFace Spaces, a wildcard elsewhere. It must be passed as the list
# itself — the quoted name would only allow an origin literally called
# "ALLOWED_ORIGINS".
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def _resolve_resource_dir(name: str) -> Optional[str]:
    """Return a path to an existing directory called *name*, or None.

    The current working directory is checked first (the historical
    behaviour), then the directory containing this module. This lets the
    app find its assets when it is launched from another folder.
    ``os.path.isdir`` is used instead of ``exists`` because StaticFiles
    rejects a path that is not a directory, so a stray file named "static"
    would otherwise crash start-up.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    for candidate in (name, os.path.join(module_dir, name)):
        if os.path.isdir(candidate):
            return candidate
    return None


# Mount static files (only if the directory exists)
static_dir = "static"
_static_path = _resolve_resource_dir(static_dir)
if _static_path is not None:
    app.mount("/static", StaticFiles(directory=_static_path), name="static")
else:
    logger.warning(f"Static directory '{static_dir}' not found")

# Templates (only if the directory exists); ``home`` serves a fallback page
# when ``templates`` is None.
templates_dir = "templates"
_templates_path = _resolve_resource_dir(templates_dir)
if _templates_path is not None:
    templates = Jinja2Templates(directory=_templates_path)
else:
    templates = None
    logger.warning(f"Templates directory '{templates_dir}' not found")
# Server state held in plain in-memory dicts (lost when the process exits).
# Session records, keyed by session id.
user_sessions: Dict[str, Dict] = dict()
# Currently open WebSocket per session id.
active_connections: Dict[str, WebSocket] = dict()
# Single engine instance shared by all REST and WebSocket handlers.
chat_engine = MedicalResearchEngine()
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def validate_domain(domain: str) -> str:
    """Return *domain* when it is a known domain id, else ``"general_medical"``.

    Unknown values are logged at WARNING level before the fallback is used.
    """
    if domain in VALID_DOMAINS:
        return domain
    logger.warning(f"Invalid domain '{domain}', defaulting to 'general_medical'")
    return "general_medical"
def validate_user_context(user_context: str) -> str:
    """Return *user_context* when it is a known context id, else ``"auto"``.

    Unknown values are logged at WARNING level before the fallback is used.
    """
    is_known = user_context in VALID_USER_CONTEXTS
    if not is_known:
        logger.warning(f"Invalid user_context '{user_context}', defaulting to 'auto'")
    return user_context if is_known else "auto"
def get_domain_by_id(domain_id: str) -> Optional[Dict]:
    """Return the MEDICAL_DOMAINS entry whose ``id`` is *domain_id*, or None."""
    return next(
        (entry for entry in MEDICAL_DOMAINS if entry["id"] == domain_id),
        None,
    )
def get_user_context_by_id(context_id: str) -> Optional[Dict]:
    """Return the USER_CONTEXTS entry whose ``id`` is *context_id*, or None."""
    found = None
    for ctx in USER_CONTEXTS:
        if ctx["id"] == context_id:
            found = ctx
            break
    return found
def split_into_chunks(text: str, chunk_size: int = 200) -> List[str]:
    """Split *text* into consecutive pieces of at most *chunk_size* characters.

    Used to simulate streaming over the WebSocket. An empty string yields an
    empty list.

    Raises:
        ValueError: if *chunk_size* is less than 1. A negative size would
            otherwise silently produce no chunks and drop the whole text.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
# ============================================================================
# ROUTES
# ============================================================================
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the chat UI, or a minimal placeholder page when templates are missing."""
    if templates is None:
        # Fallback HTML if templates not found
        return HTMLResponse(content="\nMedical Research AI\n")
    return templates.TemplateResponse("index.html", {"request": request})
@app.get("/api/health")
async def health_check():
    """Report liveness, engine configuration and in-memory usage statistics.

    ``engine_status`` contains the result of the engine's
    ``get_engine_status()`` when the engine provides that method, otherwise
    an empty dict.
    """
    engine_status = (
        chat_engine.get_engine_status()
        if hasattr(chat_engine, 'get_engine_status')
        else {}
    )
    return {
        "status": "healthy",
        "engine": "Medical Research Engine",
        "version": "1.0.0",
        "timestamp": datetime.now().isoformat(),
        "engine_configured": getattr(chat_engine, 'api_configured', False),
        "engine_status": engine_status,
        "features": [
            "Evidence-Based Medical Analysis",
            "Domain-Specific Research",
            "User Context Adaptation",
            "Paper Summarization",
        ],
        "stats": {
            "domains_count": len(MEDICAL_DOMAINS),
            "user_contexts_count": len(USER_CONTEXTS),
            "active_sessions": len(user_sessions),
            "active_connections": len(active_connections),
        },
    }
@app.get("/api/v1/domains")
async def get_domains():
    """List every medical domain, including the ``auto`` pseudo-domain."""
    payload = {"success": True, "domains": MEDICAL_DOMAINS}
    payload["count"] = len(MEDICAL_DOMAINS)
    payload["timestamp"] = datetime.now().isoformat()
    return payload
@app.get("/api/v1/domains/{domain_id}")
async def get_domain_info(domain_id: str):
    """Return a single domain by id, or a 404 JSON error for unknown ids."""
    domain = get_domain_by_id(domain_id)
    if domain is not None:
        return {
            "success": True,
            "domain": domain,
            "timestamp": datetime.now().isoformat(),
        }
    return JSONResponse(
        status_code=404,
        content={"error": f"Domain '{domain_id}' not found"},
    )
@app.get("/api/v1/user_contexts")
async def get_user_contexts():
    """List the audience profiles a session can adopt."""
    return dict(
        success=True,
        user_contexts=USER_CONTEXTS,
        count=len(USER_CONTEXTS),
        timestamp=datetime.now().isoformat(),
    )
@app.post("/api/v1/session/create")
async def create_session(request: SessionCreate = None):
    """Create a chat session and return it together with a welcome message.

    A missing body is treated as ``SessionCreate()``. Without a
    ``session_id`` a UUID4 is generated. An existing session with the same
    id is replaced by a fresh record. ``user_context`` is validated, and
    unknown values become ``"auto"``.
    """
    if request is None:
        request = SessionCreate()

    session_id = request.session_id or str(uuid.uuid4())
    user_context = validate_user_context(request.user_context)

    # One timestamp so created_at and last_active agree for a new session.
    now = datetime.now().isoformat()
    user_sessions[session_id] = {
        "id": session_id,
        "created_at": now,
        "user_context": user_context,
        "message_count": 0,
        "domains_used": set(),
        "last_active": now,
    }

    # Initialize engine for this session
    if hasattr(chat_engine, 'initialize_session'):
        chat_engine.initialize_session(session_id)

    context_info = get_user_context_by_id(user_context)
    context_name = context_info['name'] if context_info else user_context
    # "auto" and "general_medical" are catalogue meta-entries, not specialties.
    specialty_count = sum(
        1 for d in MEDICAL_DOMAINS if d["id"] not in ("auto", "general_medical")
    )
    welcome_message = "\n".join([
        "👋 **Welcome to Medical Research Assistant!** 🧬",
        f"👤 **Your session context:** {context_name}",
        f"🏥 **Available Specialties:** {specialty_count} medical domains",
        "I can help you with:",
        "• **Medical Research Analysis** - Evidence-based insights",
        "• **Paper Summarization** - Key findings and implications",
        "• **Clinical Context** - Domain-specific applications",
        "• **Research Gap Identification** - Future directions",
        "**Try asking:**",
        '• "Latest treatments for diabetes" (Endocrinology)',
        "• \"Research on Alzheimer's biomarkers\" (Neurology)",
        '• "Clinical guidelines for hypertension" (Cardiology)',
        '• "Summarize recent advances in cancer immunotherapy" (Oncology)',
        "I'll adapt my responses based on your role and medical domain.",
    ])

    return {
        "session_id": session_id,
        "user_context": user_context,
        "context_info": context_info,
        "created_at": now,
        "welcome_message": welcome_message,
    }
@app.post("/api/v1/chat")
async def chat_endpoint(request: ChatRequest):
    """Answer one chat message through the research engine.

    ``domain`` and ``user_context`` are validated, and ``max_papers`` is
    clamped to 1..50. When the session exists, its activity statistics are
    updated. Its stored user_context, if set, takes precedence over the one
    in the request. Any exception is logged and returned as a
    ``ChatResponse`` with ``success=False`` rather than raised.
    """
    try:
        domain = validate_domain(request.domain)
        user_context = validate_user_context(request.user_context)
        # Clamp out-of-range values instead of rejecting the request.
        max_papers = min(max(request.max_papers, 1), 50)

        session = user_sessions.get(request.session_id)
        if session is not None:
            session["last_active"] = datetime.now().isoformat()
            session["message_count"] = session.get("message_count", 0) + 1
            # Sessions created over the WebSocket may not have this key yet.
            session.setdefault("domains_used", set()).add(domain)
            # domains_used is an unordered set, so remember the latest domain
            # explicitly for the session-info endpoint.
            session["last_domain"] = domain
            # Use session user_context if available
            if session.get("user_context"):
                user_context = session["user_context"]
            else:
                session["user_context"] = user_context

        logger.info(f"Processing chat - Domain: {domain}, Context: {user_context}")

        start_time = datetime.now()
        response = await chat_engine.process_query_async(
            query=request.message,
            domain=domain,
            session_id=request.session_id,
            user_context=user_context,
            max_papers=max_papers,
        )
        processing_time = (datetime.now() - start_time).total_seconds()

        # Re-read the session: it may have been deleted while awaiting the engine.
        session = user_sessions.get(request.session_id)
        if session is not None:
            session.setdefault("query_types", []).append(
                response.get("query_type", "unknown")
            )

        # The engine may report confidence as {"overall_score": x}, as a bare
        # number, or not at all; only the dict form was handled before.
        confidence = response.get("confidence_score")
        if isinstance(confidence, dict):
            confidence_score = confidence.get("overall_score", 0)
        elif isinstance(confidence, (int, float)):
            confidence_score = confidence
        else:
            confidence_score = 0

        return ChatResponse(
            success=True,
            message=response.get("answer", "No response generated"),
            session_id=request.session_id,
            processing_time=processing_time,
            confidence_score=confidence_score,
            papers_used=response.get("papers_used", 0),
            user_context=response.get("user_context", user_context),
            raw_response=response,
        )
    except Exception as e:
        logger.error(f"Chat endpoint error: {str(e)}", exc_info=True)
        return ChatResponse(
            success=False,
            message=f"❌ Error: {str(e)}",
            session_id=request.session_id,
            error=str(e),
            user_context=request.user_context,
        )
# In-flight reply tasks spawned by websocket_chat. The event loop keeps only
# weak references to tasks, so each one is held here until it finishes
# (a done-callback removes it).
_websocket_reply_tasks: set = set()


@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket):
    """Serve one real-time chat connection.

    Each incoming JSON frame is dispatched on its ``type``:

    * ``init_session``: create the session if needed and register this
      socket in ``active_connections``.
    * ``message``: answer in a background task (``process_websocket_message``)
      so the receive loop stays responsive.
    * ``update_context`` / ``update_domain``: validate and echo the new value.
    * ``clear_history``: ask the engine to clear its memory.
    * ``get_domains`` / ``get_contexts``: return the catalogues.

    Session-scoped types are ignored until ``init_session`` has been handled.
    The socket is unregistered when the loop ends for any reason.
    """
    await websocket.accept()
    session_id = None
    user_context = "auto"
    try:
        while True:
            data = await websocket.receive_json()
            message_type = data.get("type")

            if message_type == "init_session":
                session_id = data.get("session_id") or str(uuid.uuid4())
                user_context = validate_user_context(data.get("user_context", "auto"))
                if session_id not in user_sessions:
                    now = datetime.now().isoformat()
                    # Same record shape as REST-created sessions so the REST
                    # chat and session endpoints also work for this session.
                    user_sessions[session_id] = {
                        "id": session_id,
                        "created_at": now,
                        "user_context": user_context,
                        "message_count": 0,
                        "domains_used": set(),
                        "last_active": now,
                    }
                    if hasattr(chat_engine, 'initialize_session'):
                        chat_engine.initialize_session(session_id)
                active_connections[session_id] = websocket
                await websocket.send_json({
                    "type": "session_created",
                    "session_id": session_id,
                    "user_context": user_context,
                    "timestamp": datetime.now().isoformat(),
                    "features": [
                        "medical_research_analysis",
                        "domain_specific_insights",
                        "user_context_adaptation"
                    ],
                    "stats": {
                        "domains_available": len(MEDICAL_DOMAINS),
                        "user_contexts_available": len(USER_CONTEXTS)
                    }
                })

            elif message_type == "message" and session_id:
                user_message = data.get("message", "")
                domain = validate_domain(data.get("domain", "general_medical"))
                user_context = validate_user_context(data.get("user_context", user_context))
                session = user_sessions.get(session_id)
                if session is not None:
                    session["user_context"] = user_context
                    # Keep activity statistics in step with the REST endpoint.
                    session["last_active"] = datetime.now().isoformat()
                    session["message_count"] = session.get("message_count", 0) + 1
                    session.setdefault("domains_used", set()).add(domain)
                    session["last_domain"] = domain
                await websocket.send_json({
                    "type": "typing",
                    "is_typing": True
                })
                task = asyncio.create_task(
                    process_websocket_message(
                        websocket, session_id, user_message,
                        domain, user_context, data
                    )
                )
                _websocket_reply_tasks.add(task)
                task.add_done_callback(_websocket_reply_tasks.discard)

            elif message_type == "update_context" and session_id:
                user_context = validate_user_context(data.get("user_context", "auto"))
                if session_id in user_sessions:
                    user_sessions[session_id]["user_context"] = user_context
                await websocket.send_json({
                    "type": "context_updated",
                    "user_context": user_context,
                    "context_info": get_user_context_by_id(user_context),
                    "session_id": session_id
                })

            elif message_type == "update_domain" and session_id:
                new_domain = validate_domain(data.get("domain", "general_medical"))
                await websocket.send_json({
                    "type": "domain_updated",
                    "domain": new_domain,
                    "domain_info": get_domain_by_id(new_domain),
                    "session_id": session_id
                })

            elif message_type == "clear_history" and session_id:
                # NOTE(review): clear_memory() receives no session id, so it
                # may reset engine memory for every session — confirm.
                if hasattr(chat_engine, 'clear_memory'):
                    chat_engine.clear_memory()
                await websocket.send_json({
                    "type": "history_cleared",
                    "session_id": session_id
                })

            elif message_type == "get_domains":
                await websocket.send_json({
                    "type": "domains_list",
                    "domains": MEDICAL_DOMAINS,
                    "count": len(MEDICAL_DOMAINS)
                })

            elif message_type == "get_contexts":
                await websocket.send_json({
                    "type": "contexts_list",
                    "user_contexts": USER_CONTEXTS,
                    "count": len(USER_CONTEXTS)
                })
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected: {session_id}")
    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}", exc_info=True)
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"Connection error: {str(e)}"
            })
        except Exception:
            # The socket is most likely already closed; nothing left to report to.
            logger.debug("Could not send error frame to WebSocket client", exc_info=True)
    finally:
        # Unregister only if the entry still refers to this socket: another
        # connection may have re-registered the same session id, or the
        # session may already have been deleted.
        if session_id and active_connections.get(session_id) is websocket:
            del active_connections[session_id]
async def process_websocket_message(websocket: WebSocket, session_id: str,
                                    user_message: str, domain: str,
                                    user_context: str, data: dict):
    """Answer one WebSocket chat message and stream the reply back.

    Frames are sent in this order:

    1. ``context_info``.
    2. The answer as ``message_chunk`` frames of up to 200 characters,
       0.05 s apart.
    3. ``message_complete`` with metadata.

    *data* is the raw client frame. Only its ``max_papers`` is read, and it
    is clamped to 1..50 (default 15), matching the REST chat endpoint. On
    failure an ``error`` frame is attempted; errors while sending it are
    logged, not raised.
    """
    try:
        try:
            max_papers = int(data.get("max_papers", 15))
        except (TypeError, ValueError):
            max_papers = 15
        max_papers = min(max(max_papers, 1), 50)

        response = await chat_engine.process_query_async(
            query=user_message,
            domain=domain,
            session_id=session_id,
            user_context=user_context,
            max_papers=max_papers,
        )
        effective_context = response.get("user_context", user_context)

        await websocket.send_json({
            "type": "context_info",
            "user_context": effective_context,
            "domain": domain,
            "domain_info": get_domain_by_id(domain),
            "context_info": get_user_context_by_id(user_context),
        })

        # A missing or None answer becomes "" so chunking cannot fail.
        answer = response.get("answer") or ""
        chunks = split_into_chunks(answer, 200)
        last_index = len(chunks) - 1
        for index, chunk in enumerate(chunks):
            await websocket.send_json({
                "type": "message_chunk",
                "chunk": chunk,
                "is_final": index == last_index,
                "chunk_index": index,
                "total_chunks": len(chunks),
            })
            await asyncio.sleep(0.05)  # Small delay for streaming effect

        # Accept {"overall_score": x}, a bare number, or nothing at all.
        confidence = response.get("confidence_score")
        if isinstance(confidence, dict):
            confidence_score = confidence.get("overall_score", 0)
        elif isinstance(confidence, (int, float)):
            confidence_score = confidence
        else:
            confidence_score = 0

        await websocket.send_json({
            "type": "message_complete",
            "message": answer,
            "metadata": {
                "confidence_score": confidence_score,
                "papers_used": response.get("papers_used", 0),
                "user_context": effective_context,
                "domain": domain,
                "query_type": response.get("query_type", "general"),
            },
        })
    except Exception as e:
        logger.error(f"WebSocket message processing error: {str(e)}", exc_info=True)
        try:
            await websocket.send_json({
                "type": "error",
                "message": f"Processing error: {str(e)}"
            })
        except Exception:
            # Typically the client disconnected while the answer was being built.
            logger.debug(
                "Could not send error frame for session %s", session_id, exc_info=True
            )
@app.get("/api/v1/session/{session_id}")
async def get_session_info(session_id: str):
    """Return a summary of a stored session, or a 404 JSON error.

    ``last_domain_info`` describes the session's ``last_domain`` when one is
    recorded. Otherwise it falls back to an element of ``domains_used``.
    That is a set, so which domain was used last cannot be recovered and
    the fallback choice is arbitrary. ``domains_used`` is returned sorted so
    the output is deterministic.
    """
    session = user_sessions.get(session_id)
    if session is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found"}
        )

    domains_used = session.get("domains_used") or set()
    last_domain = session.get("last_domain")
    if not last_domain and domains_used:
        last_domain = next(iter(domains_used))
    domain_info = get_domain_by_id(last_domain) if last_domain else None

    user_context = session.get("user_context", "auto")
    return {
        "session_id": session_id,
        "created_at": session.get("created_at"),
        "user_context": user_context,
        "context_info": get_user_context_by_id(user_context),
        "message_count": session.get("message_count", 0),
        "last_active": session.get("last_active"),
        "domains_used": sorted(domains_used),
        "last_domain_info": domain_info,
        "query_types": session.get("query_types", []),
        "features_enabled": [
            "medical_research_analysis",
            "domain_specific_insights",
            "user_context_adaptation"
        ]
    }
@app.put("/api/v1/session/{session_id}/context")
async def update_session_context(session_id: str, request: dict):
    """Replace the user context stored on an existing session (404 if unknown)."""
    session = user_sessions.get(session_id)
    if session is None:
        return JSONResponse(status_code=404, content={"error": "Session not found"})

    new_context = validate_user_context(request.get("user_context", "auto"))
    session["user_context"] = new_context
    return {
        "success": True,
        "session_id": session_id,
        "user_context": new_context,
        "context_info": get_user_context_by_id(new_context),
        "message": f"User context updated to {new_context}",
    }
@app.delete("/api/v1/session/{session_id}")
async def delete_session(session_id: str):
    """Delete a session and close its WebSocket, if one is registered.

    Always reports success, including for ids that do not exist.
    """
    if session_id in user_sessions:
        # NOTE(review): clear_memory() takes no session id, so this may reset
        # engine memory for every session — confirm the engine's semantics.
        if hasattr(chat_engine, 'clear_memory'):
            chat_engine.clear_memory()
        user_sessions.pop(session_id, None)

        # Detach the socket *before* awaiting close(): while close() is pending
        # the WebSocket handler can run and unregister the same id, and a
        # later `del` would then raise KeyError.
        websocket = active_connections.pop(session_id, None)
        if websocket is not None:
            try:
                await websocket.close()
            except Exception:
                # Best effort: the client may already have disconnected.
                logger.debug(
                    "Could not close WebSocket for session %s", session_id, exc_info=True
                )
    return {"success": True, "message": "Session deleted"}
@app.get("/api/v1/engine/status")
async def get_engine_status():
    """Expose the engine's self-reported status merged over a few static fields.

    Keys returned by the engine override the static ones. Engines without
    ``get_engine_status`` yield ``success: False``.
    """
    if not hasattr(chat_engine, 'get_engine_status'):
        return {
            "success": False,
            "engine": "Unknown",
            "message": "Engine status not available",
        }
    status = chat_engine.get_engine_status()
    return {
        "success": True,
        "engine": "Medical Research Engine",
        "domains_supported": len(MEDICAL_DOMAINS),
        "user_contexts_supported": len(USER_CONTEXTS),
        **status,
    }
# ============================================================================
# DEVELOPMENT ONLY - Local server run
# ============================================================================
if __name__ == "__main__" and os.getenv("VERCEL") is None:
    # Only run locally, not on Vercel
    import uvicorn

    # Honour PORT (it is set to 7860 when SPACE_ID is present); default 8000.
    port = int(os.getenv("PORT", "8000"))
    banner = "=" * 60
    print(f"\n{banner}")
    print("🚀 STARTING MEDICAL RESEARCH AI SERVER (LOCAL)")
    print(banner)
    print(f"📚 API Docs: http://localhost:{port}/api/docs")
    print(f"🏥 Medical Domains: {len(MEDICAL_DOMAINS)}")
    print(f"👤 User Contexts: {len(USER_CONTEXTS)}")
    print(f"{banner}\n")
    # reload=True is not used: uvicorn only supports reload when the app is
    # given as an import string, and exits when handed an app object.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
    )