import asyncio
import json
import threading
import uuid

from fastapi import APIRouter
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from app.agents.graph import build_graph
from app.logger import setup_logger
from app.utils.report_formatter import (
    parse_markdown_report,
    format_report_for_display,
    extract_citations,
    get_report_stats,
)
|
|
# Logger for this module, created via the project's own setup helper.
logger = setup_logger(__name__)


# Router the endpoints below register on (presumably included by the app
# elsewhere — confirm).
router = APIRouter()
# The agent graph is built once, at import time, and the same instance is
# used by every request to /run and /stream.
# NOTE(review): /run invokes it from worker threads via asyncio.to_thread;
# confirm the object returned by build_graph() tolerates concurrent use.
graph = build_graph()
|
|
|
|
class RunRequest(BaseModel):
    """Request body accepted by the ``/run`` endpoint."""

    # Free-text query; copied into the graph's initial state as "query".
    query: str
|
|
|
|
@router.post("/run")
async def run_agent(req: RunRequest):
    """Run the agent graph to completion for one query and return the report.

    The synchronous ``graph.invoke`` call is executed in a worker thread via
    ``asyncio.to_thread`` so the event loop is not blocked while it runs.
    Every call gets a fresh ``thread_id`` that is passed to the graph in
    ``config["configurable"]`` and echoed back to the caller.

    Args:
        req: Request body carrying the user's ``query``.

    Returns:
        A dict with ``thread_id``, the display-formatted ``final_report``,
        its parsed form ``parsed_report``, ``report_stats`` and the graph's
        ``logs`` (an empty list if the result carries none).

    Raises:
        Exception: Whatever the graph or the report formatters raise is
            logged with its traceback and re-raised unchanged.
    """
    thread_id = str(uuid.uuid4())
    logger.info(f"POST /run - Query: {req.query[:100]}... [Thread ID: {thread_id}]")

    state = {
        "query": req.query,
        "research": [],
        "critic_feedback": None,
        "final_report": None,
        "iteration": 0,
        "logs": [],
    }
    config = {"configurable": {"thread_id": thread_id}}

    try:
        logger.debug(f"Invoking graph for thread {thread_id}")
        result = await asyncio.to_thread(graph.invoke, state, config)
        logger.info(f"Graph execution completed for thread {thread_id}")

        # The initial state stores final_report=None, so the key can be present
        # but still None if no report was produced; `.get(key, "")` would then
        # hand None to the formatters. Normalise both "missing" and "None" to "".
        final_report = result.get("final_report") or ""
        logger.debug(f"Final report length: {len(final_report)} chars")

        formatted_report = format_report_for_display(final_report)
        parsed_report = parse_markdown_report(formatted_report)
        report_stats = get_report_stats(formatted_report)

        return {
            "thread_id": thread_id,
            "final_report": formatted_report,
            "parsed_report": parsed_report,
            "report_stats": report_stats,
            "logs": result.get("logs", []),
        }
    except Exception as e:
        logger.error(f"Error in /run endpoint for thread {thread_id}: {e}", exc_info=True)
        raise
|
|
|
|
@router.get("/stream")
async def stream_agent(query: str):
    """Stream the agent graph's step updates to the client as Server-Sent Events.

    Each item yielded by ``graph.stream`` is sent as an ``update`` event whose
    ``data`` is the JSON-encoded item. When an item has a ``writer`` dict
    containing ``final_report``, a copy of that dict is enriched with
    ``final_report_formatted`` and ``report_stats`` before encoding. If the
    graph, the formatting or the JSON encoding fails, one ``error`` event
    carrying the exception message is sent and the stream ends.

    ``graph.stream`` is a synchronous iterator whose steps do the graph's
    work on the calling thread. Iterating it directly inside this async
    generator would stall the event loop (and every other request) during
    each step. It is therefore driven to completion in a single worker
    thread, which keeps one thread and one copied context for the
    generator's whole life, and its items are handed to the loop through an
    ``asyncio.Queue``.

    Args:
        query: The user's query, taken from the query string.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator.
    """

    async def event_generator():
        thread_id = str(uuid.uuid4())
        logger.info(f"GET /stream - Query: {query[:100]}... [Thread ID: {thread_id}]")

        state = {
            "query": query,
            "research": [],
            "critic_feedback": None,
            "final_report": None,
            "iteration": 0,
            "logs": [],
        }
        config = {"configurable": {"thread_id": thread_id}}

        loop = asyncio.get_running_loop()
        queue = asyncio.Queue()
        # Set once the consumer side is finished (completed, failed, or the
        # client disconnected) so the worker stops pulling further graph steps.
        stop = threading.Event()

        def post(kind, payload=None):
            """Hand a ``(kind, payload)`` message to the event loop from the worker."""
            try:
                loop.call_soon_threadsafe(queue.put_nowait, (kind, payload))
            except RuntimeError:
                # The loop is already closed (e.g. server shutdown): nobody is
                # left to receive messages, so just stop producing.
                stop.set()

        def produce():
            """Drive ``graph.stream`` in the worker thread, forwarding every item."""
            try:
                for update in graph.stream(state, config=config):
                    if stop.is_set():
                        # Leaving the loop drops (and thereby closes) the
                        # graph generator after the step that just finished.
                        break
                    post("update", update)
            except Exception as exc:
                # Re-raised on the event-loop side so it is logged and
                # reported to the client like any other stream failure.
                post("error", exc)
            finally:
                post("done")

        def format_update(update):
            """Return a shallow copy of ``update`` with the writer's report formatted."""
            formatted = update.copy()
            writer = update.get("writer")
            # Only dict payloads can be enriched; anything else (e.g. None)
            # is passed through untouched instead of failing on .copy().
            if isinstance(writer, dict) and "final_report" in writer:
                writer = writer.copy()
                report = writer["final_report"]
                writer["final_report_formatted"] = format_report_for_display(report)
                writer["report_stats"] = get_report_stats(report)
                formatted["writer"] = writer
            return formatted

        # Keep a reference to the worker task for the generator's lifetime.
        worker = asyncio.create_task(asyncio.to_thread(produce))
        event_count = 0
        try:
            logger.debug(f"Starting stream for thread {thread_id}")
            while True:
                kind, payload = await queue.get()
                if kind == "done":
                    break
                if kind == "error":
                    raise payload
                event_count += 1
                logger.debug(f"Streaming event {event_count}: {list(payload.keys())}")
                yield {
                    "event": "update",
                    "data": json.dumps(format_update(payload)),
                }
            logger.info(f"Stream completed for thread {thread_id}. Total events: {event_count}")
        except Exception as e:
            logger.error(f"Error in stream for thread {thread_id}: {e}", exc_info=True)
            yield {
                "event": "error",
                "data": json.dumps({"error": str(e)}),
            }
        finally:
            # Runs on completion, on error, and on cancellation/close when the
            # client disconnects; the worker thread cannot be killed, but it
            # stops after its current graph step.
            stop.set()
            del worker

    return EventSourceResponse(event_generator())
|
|