Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, APIRouter, UploadFile, File, HTTPException, Query, Form | |
| from fastapi.responses import FileResponse | |
| from typing import Optional | |
| from contextlib import asynccontextmanager | |
| import os | |
| import shutil | |
| import logging | |
| import json | |
| from agents.simple_tools import generate_notes_full_pipeline_from_path | |
| from agents.generator_validator import create_notes_pipeline, InteractiveFeedbackManager | |
| from agents.langgraph import run_workflow | |
| from agents.rlhf_workflows import run_rlhf_workflow | |
| from agents.rlhf_routes import rlhf_router | |
| from fastapi.middleware.cors import CORSMiddleware | |
# Configure process-wide logging at INFO level, then create the named logger
# that every handler in this module writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("financial_notes_api")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log application startup and shutdown.

    Passed to ``FastAPI(lifespan=...)``. Everything before the ``yield`` runs
    once at startup; everything after it runs once at shutdown.

    Args:
        app: The FastAPI application being started (not used by the body).
    """
    # Startup
    logger.info("Financial Notes Generator API has started.")
    yield
    # Shutdown
    logger.info("Financial Notes Generator API is shutting down.")
# The application object is built first so that middleware and routers can be
# attached to it below.
app = FastAPI(
    title="Financial Notes Generator API",
    description="API for generating financial notes, balance sheets, cash flow statements, and P&L reports with RLHF capabilities and Interactive Feedback.",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware is registered immediately after the app is created.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# debugging configuration - replace "*" with an explicit origin list before
# running in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # every origin (debugging only)
    allow_credentials=True,
    allow_methods=["*"],      # every HTTP method (GET, POST, OPTIONS, ...)
    allow_headers=["*"],      # every request header
    expose_headers=["*"],     # let the frontend read all custom response headers
)

# Single shared manager holding the interactive feedback sessions.
feedback_manager = InteractiveFeedbackManager()

# RLHF endpoints live in their own router.
app.include_router(rlhf_router)

# Router that collects the main endpoints defined in this module.
router = APIRouter()
# NOTE(review): the registration decorator is not visible in the original
# listing; GET "/" is inferred from the docstring ("Root endpoint") - confirm.
@router.get("/")
async def root():
    """
    Root endpoint for the Financial Notes Generator API.

    Returns:
        A dict with a welcome message, the API version, a short description
        and a human-readable listing of the main endpoints.
    """
    return {
        "message": "Welcome to Financial Notes Generator API",
        "version": "1.0.0",
        "description": "API for generating financial notes, balance sheets, cash flow statements, and P&L reports",
        "endpoints": {
            "notes": "POST /notes - Generate financial notes from trial balance",
            "notes-llm": "POST /notes-llm - Generate LLM-based notes with interactive feedback",
            "bs": "POST /bs - Generate balance sheet",
            "pnl": "POST /pnl - Generate P&L statement",
            "cf": "POST /cf - Generate cash flow statement",
            "docs": "/docs - API documentation"
        }
    }
# Path taken from the endpoint listing returned by root() ("POST /notes-llm").
@router.post("/notes-llm")
async def notes_llm_route(
    file: UploadFile = File(...),
    use_rlhf: bool = Query(False),
    user_api_key: Optional[str] = Form(None)
):
    """Generate LLM-based notes from an uploaded file and open a feedback session.

    The upload is saved under ``data/input``, run through the notes pipeline,
    and on success the generated file is returned with metadata in ``X-*``
    response headers (validation score, attempts, execution id, session id and,
    when available, RLHF metadata and validation feedback).

    Args:
        file: Uploaded input file.
        use_rlhf: Query flag forwarded to ``create_notes_pipeline``.
        user_api_key: OpenRouter API key supplied as a form field.

    Raises:
        HTTPException: 400 when the API key or filename is missing/invalid or
            the pipeline raises ``ValueError``; 500 when generation or
            validation fails or any other error occurs.
    """
    if not user_api_key or not user_api_key.strip():
        raise HTTPException(
            status_code=400,
            detail="Missing required parameter: 'user_api_key'. Please provide your OpenRouter API key as a form parameter (not in JSON body)."
        )

    # The filename is client-controlled: keep only its final component so a
    # name such as "../../x" cannot escape data/input.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file must have a valid filename.")
    file_path = f"data/input/{safe_name}"
    os.makedirs("data/input", exist_ok=True)
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    try:
        pipeline = create_notes_pipeline(use_rlhf=use_rlhf, user_api_key=user_api_key)
        generation_result, validation_result = pipeline.process(file_path)
        summary = pipeline.get_processing_summary()
        logger.info(f"LLM Notes Pipeline Summary: {summary}")

        if not (generation_result.success and validation_result.is_valid):
            error_detail = {
                "generation_error": generation_result.error,
                "validation_feedback": validation_result.feedback,
                "validation_score": validation_result.score,
                "attempts_made": generation_result.metadata.get("attempt", 1),
                "processing_summary": summary
            }
            raise HTTPException(status_code=500, detail=json.dumps(error_detail))

        session_id = feedback_manager.create_session(file_path)
        response = FileResponse(
            generation_result.output_path,
            filename=os.path.basename(generation_result.output_path)
        )
        response.headers["X-Generation-Method"] = "llm"
        response.headers["X-Validation-Score"] = str(validation_result.score)
        response.headers["X-Attempts-Made"] = str(generation_result.metadata.get("attempt", 1))
        # Header values must be strings; coerce in case the id is not one.
        response.headers["X-Execution-ID"] = str(generation_result.metadata.get("execution_id", ""))
        response.headers["X-Session-ID"] = session_id
        response.headers["X-Interactive-Enabled"] = "true"
        if use_rlhf and "rlhf_metadata" in generation_result.metadata:
            rlhf_data = generation_result.metadata["rlhf_metadata"]
            response.headers["X-RLHF-Statement-ID"] = str(rlhf_data.get("statement_id", ""))
            response.headers["X-RLHF-Quality-Score"] = str(rlhf_data.get("predicted_quality", ""))
            response.headers["X-RLHF-Confidence"] = str(rlhf_data.get("confidence_score", ""))
        if validation_result.feedback:
            response.headers["X-Validation-Feedback"] = json.dumps(validation_result.feedback)
        return response
    except HTTPException:
        # Pass our own HTTP errors through untouched; otherwise the generic
        # handler below would re-wrap the structured detail.
        raise
    except ValueError as ve:
        logger.error(f"API key error: {ve}")
        if "API key is required" in str(ve):
            raise HTTPException(status_code=400, detail="Missing OpenRouter API key. Please provide your API key via the 'user_api_key' form parameter.")
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error(f"LLM Notes pipeline failed: {e}")
        raise HTTPException(status_code=500, detail=f"Pipeline processing failed: {str(e)}")
# NOTE(review): no route decorator for this handler is visible in the listing;
# confirm how it is registered on `router`.
async def submit_feedback(
    session_id: str = Form(...),
    feedback_text: str = Form(...),
    feedback_type: str = Form(..., pattern="^(text|numeric|formula|suggestion)$")
):
    """Record a piece of user feedback against an existing feedback session.

    Args:
        session_id: Identifier of the feedback session.
        feedback_text: The feedback content.
        feedback_type: One of ``text``, ``numeric``, ``formula`` or
            ``suggestion`` (enforced by the form pattern).

    Returns:
        A dict with the status, session id, the UDF version returned by the
        feedback manager, the session's current iteration and a message.

    Raises:
        HTTPException: 404 when the feedback manager returns no UDF version
            (session not found); 500 on any other failure.
    """
    try:
        udf_version = feedback_manager.add_feedback(session_id, feedback_text, feedback_type)
        if udf_version is None:
            raise HTTPException(status_code=404, detail="Session not found")
        return {
            "status": "success",
            "session_id": session_id,
            "udf_version": udf_version,
            "iteration": feedback_manager.get_session(session_id).current_iteration,
            "message": "Feedback submitted and UDF generated successfully"
        }
    except HTTPException:
        # Keep the 404 above as a 404 instead of letting the generic handler
        # turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Feedback submission failed: {e}")
        raise HTTPException(status_code=500, detail=f"Feedback submission failed: {str(e)}")
# NOTE(review): no route decorator for this handler is visible in the listing;
# confirm how it is registered on `router`.
async def approve_session(session_id: str = Form(...)):
    """Approve a feedback session and report its final state.

    Args:
        session_id: Identifier of the feedback session to approve.

    Returns:
        A dict with the approval status, the session id, the session's final
        UDF, its iteration count, the number of archived UDFs and a message.

    Raises:
        HTTPException: 404 when the feedback manager reports the approval as
            unsuccessful (session not found); 500 on any other failure.
    """
    try:
        success = feedback_manager.approve_session(session_id)
        if not success:
            raise HTTPException(status_code=404, detail="Session not found")
        session = feedback_manager.get_session(session_id)
        return {
            "status": "approved",
            "session_id": session_id,
            "final_udf": session.final_udf,
            "total_iterations": session.current_iteration,
            "archived_udfs_count": len(session.archived_udfs),
            "message": "Session approved and final UDF set"
        }
    except HTTPException:
        # Keep the 404 above as a 404 instead of letting the generic handler
        # turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Session approval failed: {e}")
        raise HTTPException(status_code=500, detail=f"Session approval failed: {str(e)}")
# NOTE(review): no route decorator for this handler is visible in the listing;
# confirm how it is registered on `router` (and whether `session_id` is a path
# or query parameter).
async def get_session_info(session_id: str):
    """Return the state and full feedback history of a feedback session.

    Args:
        session_id: Identifier of the feedback session to look up.

    Returns:
        A dict with the session's id, status, iteration counters, final UDF,
        ISO-formatted creation/update timestamps and one entry per recorded
        feedback item.

    Raises:
        HTTPException: 404 when the session does not exist; 500 on any other
            failure.
    """
    try:
        session = feedback_manager.get_session(session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")
        return {
            "session_id": session.session_id,
            "status": session.status,
            "current_iteration": session.current_iteration,
            "total_feedbacks": len(session.feedback_history),
            "archived_udfs_count": len(session.archived_udfs),
            "final_udf": session.final_udf,
            "created_at": session.created_at.isoformat(),
            "last_updated": session.last_updated.isoformat(),
            "feedback_history": [
                {
                    "iteration": f.iteration_number,
                    "feedback_type": f.feedback_type,
                    "feedback_text": f.feedback_text,
                    "udf_version": f.udf_version,
                    "timestamp": f.timestamp.isoformat(),
                    "changes_description": f.changes_description
                } for f in session.feedback_history
            ]
        }
    except HTTPException:
        # Keep the 404 above as a 404 instead of letting the generic handler
        # turn it into a 500.
        raise
    except Exception as e:
        logger.error(f"Session info retrieval failed: {e}")
        raise HTTPException(status_code=500, detail=f"Session info retrieval failed: {str(e)}")
# NOTE(review): no route decorator for this handler is visible in the listing;
# confirm how it is registered on `router`.
async def generate_with_feedback(
    session_id: str = Form(...),
    file: UploadFile = File(...),
    user_api_key: Optional[str] = Form(None)
):
    """Regenerate notes for an active session, applying its accumulated feedback.

    The session's final UDF (or, when none is set, its archived UDFs) and its
    feedback history are passed to the pipeline as ``feedback_context``.

    Args:
        session_id: Identifier of an existing feedback session.
        file: Uploaded input file.
        user_api_key: OpenRouter API key supplied as a form field.

    Raises:
        HTTPException: 400 when the API key or filename is missing/invalid, the
            session is not ``active`` or the pipeline raises ``ValueError``;
            404 when the session does not exist; 500 when generation or
            validation fails or any other error occurs.
    """
    if not user_api_key or not user_api_key.strip():
        raise HTTPException(
            status_code=400,
            detail="Missing required parameter: 'user_api_key'. Please provide your OpenRouter API key as a form parameter (not in JSON body)."
        )
    try:
        session = feedback_manager.get_session(session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")
        if session.status != 'active':
            raise HTTPException(status_code=400, detail=f"Session is {session.status}")

        # The filename is client-controlled: keep only its final component so
        # a name such as "../../x" cannot escape data/input.
        safe_name = os.path.basename(file.filename or "")
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Uploaded file must have a valid filename.")
        file_path = f"data/input/{safe_name}"
        os.makedirs("data/input", exist_ok=True)
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        pipeline = create_notes_pipeline(use_rlhf=False, user_api_key=user_api_key)

        # A final UDF takes precedence; otherwise fall back to all archived ones.
        udfs_to_apply = []
        if session.final_udf:
            udfs_to_apply.append(session.final_udf)
        elif session.archived_udfs:
            udfs_to_apply.extend(session.archived_udfs)

        feedback_context = {
            'session_id': session_id,
            'udfs': udfs_to_apply,
            'feedback_history': [
                {
                    'text': f.feedback_text,
                    'type': f.feedback_type,
                    'iteration': f.iteration_number
                } for f in session.feedback_history
            ],
            'current_iteration': session.current_iteration
        }
        generation_result, validation_result = pipeline.process(file_path, feedback_context=feedback_context)

        if not (generation_result.success and validation_result.is_valid):
            error_detail = {
                "generation_error": generation_result.error,
                "validation_feedback": validation_result.feedback,
                "validation_score": validation_result.score,
                "session_id": session_id,
                "current_iteration": session.current_iteration
            }
            raise HTTPException(status_code=500, detail=json.dumps(error_detail))

        response = FileResponse(
            generation_result.output_path,
            filename=os.path.basename(generation_result.output_path)
        )
        response.headers["X-Session-ID"] = session_id
        response.headers["X-Iteration"] = str(session.current_iteration)
        response.headers["X-Feedbacks-Applied"] = str(len(session.feedback_history))
        response.headers["X-UDFs-Archived"] = str(len(session.archived_udfs))
        response.headers["X-Generation-Method"] = "llm_with_feedback"
        response.headers["X-Validation-Score"] = str(validation_result.score)
        # Header values must be strings; coerce in case the id is not one.
        response.headers["X-Execution-ID"] = str(generation_result.metadata.get("execution_id", ""))
        return response
    except HTTPException:
        raise
    except ValueError as ve:
        logger.error(f"API key error: {ve}")
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error(f"Feedback-based generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
# Path taken from the endpoint listing returned by root() ("POST /notes").
@router.post("/notes")
async def notes_route(file: UploadFile = File(...)):
    """Generate financial notes from an uploaded file.

    The upload is saved under ``data/input`` and handed to
    ``generate_notes_full_pipeline_from_path``; on success the produced
    workbook is returned as a file download.

    Args:
        file: Uploaded input file.

    Raises:
        HTTPException: 400 when the upload has no usable filename; 500 when the
            pipeline reports a non-success status or any other error occurs.
    """
    try:
        # The filename is client-controlled: keep only its final component so
        # a name such as "../../x" cannot escape data/input.
        safe_name = os.path.basename(file.filename or "")
        if safe_name in ("", ".", ".."):
            raise HTTPException(status_code=400, detail="Uploaded file must have a valid filename.")
        file_path = f"data/input/{safe_name}"
        os.makedirs("data/input", exist_ok=True)
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        result = generate_notes_full_pipeline_from_path(file_path)
        if result["status"] == "success":
            output_path = result["output_xlsx_path"]
            return FileResponse(output_path, filename=os.path.basename(output_path))
        raise HTTPException(status_code=500, detail=result.get("error", "Notes generation failed"))
    except HTTPException:
        # Pass our own HTTP errors through untouched; otherwise the generic
        # handler below would re-wrap their status and detail.
        raise
    except Exception as e:
        logger.error(f"Error in notes generation: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating notes: {str(e)}")
# Path taken from the endpoint listing returned by root() ("POST /pnl").
@router.post("/pnl")
async def pnl_route(file: UploadFile = File(...), use_rlhf: bool = Query(False)):
    """Generate a P&L statement from an uploaded file.

    Args:
        file: Uploaded input file.
        use_rlhf: When true, run ``run_rlhf_workflow`` instead of
            ``run_workflow``.

    Returns:
        The generated workbook as a file download; RLHF metadata, when present
        in the workflow result, is exposed through ``X-RLHF-*`` headers.

    Raises:
        HTTPException: 400 when the upload has no usable filename; 500 when the
            workflow reports a non-success status.
    """
    # The filename is client-controlled: keep only its final component so a
    # name such as "../../x" cannot escape data/input.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file must have a valid filename.")
    file_path = f"data/input/{safe_name}"
    os.makedirs("data/input", exist_ok=True)
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    workflow = run_rlhf_workflow if use_rlhf else run_workflow
    result = workflow(file_path, "pnl")

    if result["status"] != "success":
        # The workflow result may not carry an "error" key; avoid a KeyError.
        raise HTTPException(status_code=500, detail=result.get("error", "Unknown error"))

    output_path = result["result"].get("output_path", "data/pnl_statement.xlsx")
    response = FileResponse(output_path, filename=os.path.basename(output_path))
    if "rlhf_metadata" in result.get("result", {}):
        rlhf_data = result["result"]["rlhf_metadata"]
        response.headers["X-RLHF-Statement-ID"] = str(rlhf_data.get("statement_id", ""))
        response.headers["X-RLHF-Quality-Score"] = str(rlhf_data.get("predicted_quality", ""))
        response.headers["X-RLHF-Confidence"] = str(rlhf_data.get("confidence_score", ""))
    return response
# Path taken from the endpoint listing returned by root() ("POST /bs").
@router.post("/bs")
async def bs_route(file: UploadFile = File(...), use_rlhf: bool = Query(False), user_api_key: Optional[str] = Form(None)):
    """Generate a balance sheet from an uploaded file.

    Args:
        file: Uploaded input file.
        use_rlhf: When true, run ``run_rlhf_workflow`` instead of
            ``run_workflow``.
        user_api_key: OpenRouter API key supplied as a form field.

    Returns:
        The generated workbook as a file download; RLHF metadata, when present
        in the workflow result, is exposed through ``X-RLHF-*`` headers.

    Raises:
        HTTPException: 400 when the API key or filename is missing/invalid or
            the workflow reports a missing API key; 500 when the workflow fails
            or no Excel output can be found.
    """
    if not user_api_key or not user_api_key.strip():
        raise HTTPException(
            status_code=400,
            detail="Missing required parameter: 'user_api_key'. Please provide your OpenRouter API key as a form parameter (not in JSON body)."
        )

    # The filename is client-controlled: keep only its final component so a
    # name such as "../../x" cannot escape data/input.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file must have a valid filename.")
    file_path = f"data/input/{safe_name}"
    os.makedirs("data/input", exist_ok=True)
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    workflow = run_rlhf_workflow if use_rlhf else run_workflow
    result = workflow(file_path, "bs", user_api_key=user_api_key)

    if result["status"] != "success":
        error_msg = str(result.get("error", "Unknown error"))
        # Check if error is about missing API key
        if "Missing OpenRouter API key" in error_msg:
            raise HTTPException(
                status_code=400,
                detail="Missing OpenRouter API key. Please provide your API key via the 'user_api_key' form parameter."
            )
        raise HTTPException(status_code=500, detail=error_msg)

    output_file = result["result"].get("output_path")
    if not output_file or not os.path.isfile(output_file):
        # The workflow did not report a usable path: fall back to the output
        # directory. Guard against the directory not existing, and prefer the
        # most recently modified workbook over arbitrary listing order.
        output_dir = "data/output/"
        candidates = []
        if os.path.isdir(output_dir):
            candidates = [
                os.path.join(output_dir, name)
                for name in os.listdir(output_dir)
                if name.endswith('.xlsx') and os.path.isfile(os.path.join(output_dir, name))
            ]
        if not candidates:
            raise HTTPException(status_code=500, detail="No balance sheet Excel file produced")
        output_file = max(candidates, key=os.path.getmtime)

    response = FileResponse(output_file, filename=os.path.basename(output_file))
    if "rlhf_metadata" in result.get("result", {}):
        rlhf_data = result["result"]["rlhf_metadata"]
        response.headers["X-RLHF-Statement-ID"] = str(rlhf_data.get("statement_id", ""))
        response.headers["X-RLHF-Quality-Score"] = str(rlhf_data.get("predicted_quality", ""))
        response.headers["X-RLHF-Confidence"] = str(rlhf_data.get("confidence_score", ""))
    return response
# Path taken from the endpoint listing returned by root() ("POST /cf").
@router.post("/cf")
async def cf_route(file: UploadFile = File(...), use_rlhf: bool = Query(False)):
    """Generate a cash flow statement from an uploaded file.

    Args:
        file: Uploaded input file.
        use_rlhf: When true, run ``run_rlhf_workflow`` instead of
            ``run_workflow``.

    Returns:
        The generated workbook as a file download; RLHF metadata, when present
        in the workflow result, is exposed through ``X-RLHF-*`` headers.

    Raises:
        HTTPException: 400 when the upload has no usable filename; 500 when the
            workflow reports a non-success status.
    """
    # The filename is client-controlled: keep only its final component so a
    # name such as "../../x" cannot escape data/input.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Uploaded file must have a valid filename.")
    file_path = f"data/input/{safe_name}"
    os.makedirs("data/input", exist_ok=True)
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    workflow = run_rlhf_workflow if use_rlhf else run_workflow
    result = workflow(file_path, "cf")

    if result["status"] != "success":
        # The workflow result may not carry an "error" key; avoid a KeyError.
        raise HTTPException(status_code=500, detail=result.get("error", "Unknown error"))

    output_path = result["result"].get("output_path", "data/cash_flow_statements.xlsx")
    response = FileResponse(output_path, filename=os.path.basename(output_path))
    if "rlhf_metadata" in result.get("result", {}):
        rlhf_data = result["result"]["rlhf_metadata"]
        response.headers["X-RLHF-Statement-ID"] = str(rlhf_data.get("statement_id", ""))
        response.headers["X-RLHF-Quality-Score"] = str(rlhf_data.get("predicted_quality", ""))
        response.headers["X-RLHF-Confidence"] = str(rlhf_data.get("confidence_score", ""))
    return response
# The main router is attached last, once every handler above has been defined.
app.include_router(router)

# Running this module directly starts a uvicorn server on all interfaces,
# port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)