Spaces:
Sleeping
Sleeping
| import logging | |
| import time | |
| import uvicorn | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from contextlib import asynccontextmanager | |
| from typing import List, Dict, Any | |
| # Import necessary components from your kig_core library | |
| # Ensure kig_core is in the Python path or installed as a package | |
| try: | |
| from kig_core.config import settings # Loads config on import | |
| from kig_core.schemas import PlannerState, KeyIssue as KigKeyIssue, GraphConfig | |
| from kig_core.planner import build_graph | |
| from kig_core.graph_client import neo4j_client # Import the initialized client instance | |
| from langchain_core.messages import HumanMessage | |
| except ImportError as e: | |
| print(f"Error importing kig_core components: {e}") | |
| print("Please ensure kig_core is in your Python path or installed.") | |
| # You might want to exit or raise a clearer error if imports fail | |
| raise | |
# --- Logging ---
# Process-wide logging setup: INFO and above, each line stamped with time,
# emitting logger name and level.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used by the lifecycle hook and the request handlers below.
logger = logging.getLogger(__name__)
# --- Pydantic Models for API Request/Response ---
class KeyIssueRequest(BaseModel):
    """Incoming payload for key-issue generation.

    Attributes:
        query: Free-text technical question to analyse.
    """

    # Pydantic only enforces the type; emptiness is checked by the request handler.
    query: str
class KeyIssueResponse(BaseModel):
    """Outgoing payload carrying the key issues generated for a query.

    Attributes:
        key_issues: Items validated against kig_core's ``KeyIssue`` schema
            (imported in this module as ``KigKeyIssue``).
    """

    key_issues: List[KigKeyIssue]
# --- Global Variables / State ---
# Compiled LangGraph workflow shared by every request. It remains ``None`` until
# the lifespan startup hook assigns it, and the request handler answers 503
# while it is unset.
# NOTE(review): one instance is reused across concurrent requests — confirm the
# graph and its LLM clients keep no per-request mutable state.
app_graph: Any = None
# --- Application Lifecycle (Startup/Shutdown) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the API serves requests, and shutdown work after.

    Startup checks Neo4j connectivity (logging, but tolerating, a failure) and
    builds the LangGraph workflow into the module-level ``app_graph``.

    Args:
        app: The FastAPI instance. It is not used here but is required by the
            lifespan signature.

    Raises:
        RuntimeError: If ``build_graph()`` fails. This aborts application startup.
    """
    global app_graph
    logger.info("API starting up...")

    # Best-effort connectivity check: a failure is logged and startup continues.
    try:
        logger.info("Verifying Neo4j connection...")
        # Called without ``await``, so this runs synchronously on the event-loop thread.
        neo4j_client._get_driver().verify_connectivity()
        logger.info("Neo4j connection verified.")
    except Exception as e:
        logger.error(f"Neo4j connection verification failed on startup: {e}", exc_info=True)
        # To refuse to start without a database, re-raise instead:
        # raise RuntimeError("Failed to connect to Neo4j on startup.") from e

    logger.info("Building LangGraph application...")
    try:
        app_graph = build_graph()
        logger.info("LangGraph application built successfully.")
    except Exception as e:
        logger.error(f"Failed to build LangGraph application on startup: {e}", exc_info=True)
        # Without a graph every request would answer 503, so fail startup instead.
        raise RuntimeError("Failed to build LangGraph on startup.") from e

    try:
        yield  # The API serves requests while suspended here.
    finally:
        logger.info("API shutting down...")
        # The Neo4j client is not closed here; closing is left to graph_client
        # (reportedly via an atexit handler, which runs after this hook), so do not
        # log it as already closed.
        logger.info("Neo4j client close deferred to graph_client's atexit handler.")
        logger.info("API shutdown complete.")
# --- FastAPI Application ---
# Application object served by uvicorn. The lifespan hook above runs its startup
# code before requests are handled and its shutdown code afterwards.
app = FastAPI(
    lifespan=lifespan,
    title="Key Issue Generator API",
    version="1.0.0",
    description="API to generate Key Issues based on a technical query using LLMs and Neo4j.",
)
# --- API Endpoint ---
# API state check route.
# NOTE(review): this handler had no route decorator, so FastAPI never exposed it;
# "/" is assumed from the name ``read_root`` — confirm against health-check clients.
@app.get("/")
def read_root():
    """Report that the API process is up and serving requests.

    Returns:
        ``{"status": "ok"}`` unconditionally. Neo4j and the graph are not checked.
    """
    return {"status": "ok"}
def _status_for_workflow_error(error_msg: str) -> int:
    """Map a workflow error message to an HTTP status code by substring matching.

    Neo4j/connection problems map to 503, LLM/parse problems map to 502, and
    anything else maps to 500.
    """
    if "Neo4j" in error_msg or "connection" in error_msg.lower():
        return 503  # Service Unavailable (database issue)
    if "LLM error" in error_msg or "parse" in error_msg.lower():
        return 502  # Bad Gateway (issue with upstream LLM)
    return 500  # Internal Server Error by default


# NOTE(review): this handler had no route decorator, so FastAPI never exposed it.
# response_model follows the original comments about validating against
# KeyIssueResponse; the path is a reconstruction — confirm it against API clients.
@app.post("/generate-key-issues", response_model=KeyIssueResponse)
async def generate_issues(request: KeyIssueRequest):
    """Accept a technical query and return the Key Issues generated for it.

    Builds an initial planner state from the query and runs the module-level
    LangGraph workflow (``app_graph``) once. It then validates the final state's
    ``key_issues`` against ``KeyIssueResponse``.

    Args:
        request: Request body carrying the user's technical ``query``.

    Returns:
        KeyIssueResponse: The generated key issues.

    Raises:
        HTTPException: The status code depends on the failure:
            - 503: the graph is not initialized, the workflow reports a
              Neo4j/connection error, or a ``ConnectionError`` escapes.
            - 400: the query is blank, or a ``ValueError`` escapes.
            - 502: the workflow reports an LLM/parse error.
            - 500: there is no final state, the result cannot be formatted,
              or any other unexpected error occurs.
    """
    # Read-only access to the module-level graph, so no ``global`` statement is needed.
    if app_graph is None:
        logger.error("Graph application is not initialized.")
        raise HTTPException(status_code=503, detail="Service Unavailable: Graph not initialized")

    user_query = request.query
    # Reject whitespace-only queries as well as empty ones.
    if not user_query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty.")

    logger.info(f"Received request to generate key issues for query: '{user_query[:100]}...'")
    # perf_counter is the stdlib clock intended for measuring elapsed durations.
    start_time = time.perf_counter()

    try:
        # Initial state for the planner graph. Keep these keys aligned with
        # kig_core's PlannerState and the graph's entry node.
        initial_state: PlannerState = {
            "user_query": user_query,
            "messages": [HumanMessage(content=user_query)],
            "plan": [],
            "current_plan_step_index": -1,  # Or as expected by your graph's entry point
            "step_outputs": {},
            "key_issues": [],
            "error": None,
        }
        # No memory/checkpointer is used, so no thread_id is configured. If memory
        # is added, supply {"configurable": {"thread_id": ...}} per conversation.
        config: GraphConfig = {"configurable": {}}

        logger.info("Invoking LangGraph workflow...")
        # ainvoke yields only the final state; switch to astream to observe
        # intermediate node outputs.
        final_state = await app_graph.ainvoke(initial_state, config=config)
        logger.info(f"Workflow finished in {time.perf_counter() - start_time:.2f} seconds.")

        if final_state is None:
            logger.error("Workflow execution did not produce a final state.")
            raise HTTPException(status_code=500, detail="Workflow execution failed to produce a result.")

        if final_state.get("error"):
            # Coerce to str so the substring checks cannot raise TypeError when the
            # state holds a non-string error value (e.g. an exception object).
            error_msg = str(final_state["error"])
            logger.error(f"Workflow failed with error: {error_msg}")
            raise HTTPException(
                status_code=_status_for_workflow_error(error_msg),
                detail=f"Workflow failed: {error_msg}",
            )

        generated_issues_data = final_state.get("key_issues", [])
        # Validate explicitly, rather than only through response_model after return,
        # so a malformed result is logged with its payload and reported as a clear 500.
        try:
            response = KeyIssueResponse(key_issues=generated_issues_data)
        except Exception as pydantic_error:
            logger.error(f"Failed to validate final key issues against response model: {pydantic_error}", exc_info=True)
            logger.error(f"Data that failed validation: {generated_issues_data}")
            raise HTTPException(
                status_code=500,
                detail="Internal error: Failed to format key issues response.",
            ) from pydantic_error

        logger.info(f"Successfully generated {len(response.key_issues)} key issues.")
        return response

    except HTTPException:
        # Already mapped to a status code above; propagate unchanged.
        raise
    except ConnectionError as e:
        logger.error(f"Connection Error during API request: {e}", exc_info=True)
        raise HTTPException(status_code=503, detail=f"Service Unavailable: {e}") from e
    except ValueError as e:
        # NOTE(review): a ValueError raised inside the workflow is not necessarily the
        # client's fault; 400 is kept for compatibility with existing callers.
        logger.error(f"Value Error during API request: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Bad Request: {e}") from e
    except Exception as e:
        logger.error(f"An unexpected error occurred during API request: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal Server Error: An unexpected error occurred.") from e
# --- How to Run ---
if __name__ == "__main__":
    import os

    # Make sure to set environment variables for config (NEO4J_URI, NEO4J_PASSWORD, GEMINI_API_KEY, etc.)
    # or have a .env file in the same directory where you run this script.
    print("Starting API server...")
    print("Ensure required environment variables (e.g., NEO4J_URI, NEO4J_PASSWORD, GEMINI_API_KEY) are set or .env file is present.")

    # reload=True needs an import string ("module:attr") rather than the app object.
    # Derive the module name from this file so the string stays valid if the file is
    # renamed; it was previously hard-coded as "api".
    module_name = os.path.splitext(os.path.basename(__file__))[0]

    # Equivalent CLI: uvicorn <module>:app --reload --host 0.0.0.0 --port 8000
    # reload is for development only; use reload=False in production.
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=8000, reload=True)