Spaces:
Paused
Paused
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from crewai import Task, Crew, Agent | |
| import os | |
| import sqlite3 | |
| from pathlib import Path | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from fastapi.openapi.docs import get_swagger_ui_html | |
| from src.DiagnosticInfoAgent import DiagnosticInfoAgent | |
| from src.DoctorInfoAgent import DoctorInfoAgent | |
| from src.EmergencyServicesAgent import EmergencyServicesAgent | |
| from src.HospitalComparisonAgent import HospitalComparisonAgent | |
| from src.constants import MODEL_NAME, OPENAI_API_KEY | |
| from textwrap import dedent | |
| from openai import OpenAI | |
# Read the OpenAI API key from the environment (per the deployment notes it is
# configured as a Hugging Face secret).
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    # Fail at import time so a misconfigured deployment never starts serving.
    raise RuntimeError(
        "β OPENAI_API_KEY not found in environment variables. Set it in Hugging Face secrets."
    )

# The client is built with no explicit arguments; the check above only
# guarantees that the environment variable is present and non-empty.
client = OpenAI()
# Create the FastAPI application object used by the rest of this module.
app = FastAPI(title="Health-Sense AI")

# Register CORS handling with every origin, method and header allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive policy — confirm it is intended before adding authenticated routes.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Database files live under "<parent of this file's directory>/src".
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
DB_PATH = os.path.join(BASE_DIR, "src", "appointments.db")
EMERGENCY_DB_PATH = os.path.join(BASE_DIR, "src", "emergency.db")

# Directory CrewAI is pointed at for its own storage. The three environment
# variables below all refer to it; XDG_DATA_HOME is included so that the
# default per-user data directory is not used instead.
CREWAI_STORAGE_PATH = "/src/data"
os.environ.update(
    {
        "CREWAI_STORAGE_PATH": CREWAI_STORAGE_PATH,
        "XDG_DATA_HOME": CREWAI_STORAGE_PATH,
        "CREWAI_DATABASE_PATH": os.path.join(CREWAI_STORAGE_PATH, "crewai_storage.db"),
    }
)
# Redirect CrewAI's storage location to CREWAI_STORAGE_PATH.
import crewai.utilities.paths
from crewai.memory.storage.kickoff_task_outputs_storage import KickoffTaskOutputsSQLiteStorage


def db_storage_path():
    """Return the fixed directory that CrewAI should use for its storage."""
    storage_dir = CREWAI_STORAGE_PATH
    return storage_dir


# Replace the attribute on crewai.utilities.paths with the function above.
# NOTE(review): modules that already imported the original function by name
# keep their own reference — verify the override reaches the storage class.
setattr(crewai.utilities.paths, "db_storage_path", db_storage_path)
# Create the storage directory (and any missing parents) if it is not there yet.
# No permissions are changed here; only existence is ensured.
os.makedirs(CREWAI_STORAGE_PATH, exist_ok=True)
print(f"πΉ CrewAI Storage Path: {CREWAI_STORAGE_PATH}")
print(f"πΉ CrewAI Database Path: {os.environ['CREWAI_DATABASE_PATH']}")

# Instantiate CrewAI's kickoff-output storage once at startup. This is
# best-effort: a failure is printed and the application keeps loading.
try:
    KickoffTaskOutputsSQLiteStorage()
except Exception as exc:
    print(f"β οΈ ERROR: Could not initialize CrewAI storage: {exc}")
# Run a query against a SQLite database file and return the rows as dicts.
def query_db(db_path, query, args=()):
    """Execute ``query`` on the SQLite database at ``db_path`` and return all rows.

    Args:
        db_path: Path handed to ``sqlite3.connect`` (``":memory:"`` also works).
        query: SQL text; use ``?`` placeholders instead of string formatting.
        args: Values bound to the placeholders in ``query``.

    Returns:
        A list with one ``dict`` per result row, keyed by column name.

    Raises:
        HTTPException: Status 500 when SQLite reports any error.
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            # sqlite3.Row lets each row be converted to a dict by column name.
            conn.row_factory = sqlite3.Row
            rows = conn.execute(query, args).fetchall()
        finally:
            # Close on every path; previously a failing execute/fetch skipped
            # the close call and leaked the connection.
            conn.close()
    except sqlite3.Error as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e
    return [dict(row) for row in rows]
# Serve files from the "static" directory under the /static URL prefix.
# The directory is relative, so it is resolved against the process's working directory.
app.mount(
    path="/static",
    app=StaticFiles(directory="static"),
    name="static",
)
# Return `static/index.html` as the homepage.
# NOTE(review): no route decorator is visible above this function in this
# view — confirm it is registered on `app`.
async def serve_homepage():
    homepage = FileResponse(path="static/index.html")
    return homepage
# Build the Swagger UI page for the schema published at /openapi.json.
# NOTE(review): no route decorator is visible above this function in this
# view — confirm it is registered on `app`.
async def custom_swagger_ui_html():
    docs_page = get_swagger_ui_html(
        title="HealthSense AI API Docs",
        openapi_url="/openapi.json",
    )
    return docs_page
# Return every row of the `doctors` table from the appointments database.
def get_doctors():
    sql = "SELECT * FROM doctors"
    return query_db(DB_PATH, sql)
# Return every row of the `appointments` table from the appointments database.
def get_appointments():
    sql = "SELECT * FROM appointments"
    return query_db(DB_PATH, sql)
# Export the configured model name through the environment.
# OPENAI_API_KEY is not overwritten from src.constants here; the key already
# present in the environment is left as it is.
os.environ.update({"OPENAI_MODEL_NAME": MODEL_NAME})
# Build the four agent wrappers first, in the same order as before.
hospital_comparison_agent = HospitalComparisonAgent()
doctor_info_agent = DoctorInfoAgent()
emergency_services_agent = EmergencyServicesAgent()
diagnostic_info_agent = DiagnosticInfoAgent()

# Pull the underlying agent object out of each wrapper. The right-hand side is
# evaluated in full before any name is bound, so `diagnostic_info_agent` is
# still the wrapper when its attribute is read and is rebound afterwards.
(
    hospital_info_agent,
    doctor_slots_agent,
    emergency_agent,
    diagnostic_info_agent,
) = (
    hospital_comparison_agent.hospital_info_agent,
    doctor_info_agent.doctor_slots_agent,
    emergency_services_agent.emergency_agent,
    diagnostic_info_agent.diagnostic_info_agent,
)
# Agent that receives the routing task; delegation to other agents is enabled.
router_agent = Agent(
    allow_delegation=True,
    role="Researcher",
    goal="Research and analyze information",
    backstory="Expert at gathering and analyzing information",
)
# Request body model for AI queries.
class QueryInput(BaseModel):
    # Free-text query; handle_query embeds it in the routing task description.
    query: str
# API endpoint: route a user query through the CrewAI agents.
# NOTE(review): no route decorator is visible above handle_query in this
# view — confirm it is registered on `app`.
import asyncio
import threading

# The agents are module-level objects shared by every request, so crews are
# built and run one at a time, exactly as they were when the blocking call
# ran directly on the event loop.
_CREW_LOCK = threading.Lock()


def _run_routing_crew(query):
    """Build the routing task and crew for ``query`` and run it (blocking)."""
    with _CREW_LOCK:
        routing_task = Task(
            description=dedent(f"""
                Route this query: "{query}"
                Choose the most appropriate agent based on keywords and context.
                Return the output of the selected agent.
            """),
            expected_output="",
            agent=router_agent,
        )
        multi_agent_crew = Crew(
            agents=[
                router_agent,
                hospital_info_agent,
                doctor_slots_agent,
                emergency_agent,
                diagnostic_info_agent,
            ],
            tasks=[routing_task],
        )
        return multi_agent_crew.kickoff()


async def handle_query(input: QueryInput):
    # Returns {"query": <original text>, "responses": [<raw output per task>]}.
    # Raises HTTPException(500) when the crew produced no task output.
    query = input.query

    # kickoff() is a blocking call; running it in a worker thread keeps the
    # event loop free to serve other requests while the agents work.
    result_with_router = await asyncio.to_thread(_run_routing_crew, query)

    if not result_with_router.tasks_output:
        raise HTTPException(status_code=500, detail="No output generated by the agents.")

    responses = [task.raw for task in result_with_router.tasks_output]
    return {"query": query, "responses": responses}
# Health check: report a fixed "healthy" status.
# NOTE(review): no route decorator is visible above this function in this
# view — confirm it is registered on `app`.
async def health_check():
    return dict(status="healthy")
# Script entry point: serve the app with uvicorn when run directly.
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")