Spaces:
Sleeping
Sleeping
| import uuid | |
| import json | |
| import logging | |
| from typing import Dict, Annotated | |
| from fastapi import APIRouter, UploadFile, File, Form, HTTPException | |
| from fastapi.responses import Response | |
| from app.agent.state import AgentState | |
| from app.agent.graph import start_graph, submit_graph | |
| from app.models.question import QuestionOut, Answer | |
| from app.models.session import SessionStartResponse, RoundSubmitResponse, ResultsResponse | |
| from app.services.transcription import transcription_service | |
| from app.services.file_generator import post_process_markdown | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router whose routes are all prefixed with /api/session and tagged "session".
# NOTE(review): no handler visible in this view carries a route decorator —
# confirm how the functions below are registered on this router.
router = APIRouter(tags=["session"], prefix="/api/session")

# In-memory session storage: session id -> latest agent state for that session.
# This is a plain process-local dict, so entries live only as long as the process.
sessions: Dict[str, AgentState] = {}
async def start_session():
    """Create a new session, run the start graph, and return the first questions.

    A short random session id is generated, a fresh ``AgentState`` is built
    and passed to ``start_graph.ainvoke``, and the state returned by the graph
    is stored in the in-memory ``sessions`` mapping under that id.

    Returns:
        SessionStartResponse populated with the session id plus the
        ``current_round`` and ``current_questions`` values of the graph result.
    """
    # Eight hex characters carry only 32 bits of randomness, and ``sessions``
    # is keyed by this value alone. Regenerate on a clash so a new session can
    # never silently overwrite one that is still stored.
    session_id = uuid.uuid4().hex[:8]
    while session_id in sessions:
        session_id = uuid.uuid4().hex[:8]

    initial_state: AgentState = {
        "session_id": session_id,
        "current_round": 0,
        "max_rounds": 3,
        "current_questions": [],
        "all_answers": [],
        "round_summaries": [],
        "checklist_items": [],
        "markdown_content": "",
        "is_complete": False,
    }

    result = await start_graph.ainvoke(initial_state)
    sessions[session_id] = result

    return SessionStartResponse(
        session_id=session_id,
        round=result["current_round"],
        questions=result["current_questions"],
    )
async def transcribe_audio(audio_file: UploadFile = File(...)):
    """Transcribe one uploaded audio file.

    Args:
        audio_file: The uploaded file; its full contents are read into memory.

    Returns:
        A dict with the single key ``"transcript"`` whose value is whatever
        ``transcription_service.transcribe`` returns for the uploaded bytes.
    """
    raw_audio = await audio_file.read()
    return {"transcript": await transcription_service.transcribe(raw_audio)}
def _resolve_question_id(index: int, submitted_ids: list[str], current_questions: list) -> str:
    """Pick the question id for the upload at position *index*.

    Preference order: the id the client submitted at that position, then the
    id of the session's current question at that position, then a synthetic
    ``q{index}`` placeholder.
    """
    if index < len(submitted_ids):
        return submitted_ids[index]
    if index < len(current_questions):
        return current_questions[index].id
    return f"q{index}"


async def submit_round(
    session_id: str,
    audio_files: list[UploadFile] = File(...),
    question_ids: Annotated[str, Form()] = "",
):
    """Transcribe a round of audio answers and advance the session.

    Each uploaded file is transcribed in order and turned into an ``Answer``
    tagged with the session's current round. The answers are appended to the
    session's accumulated answers and the combined state is run through
    ``submit_graph``. The graph result replaces the stored session state.

    Args:
        session_id: Key of an existing entry in ``sessions``.
        audio_files: One upload per answered question, in question order.
        question_ids: Comma-separated question ids, positionally matched to
            ``audio_files``. Blank entries are ignored; missing positions fall
            back to the session's current questions.

    Returns:
        RoundSubmitResponse. When the graph marks the session complete it
        carries ``checklist_preview`` (the post-processed markdown); otherwise
        it carries the next round's ``questions``.

    Raises:
        HTTPException: 404 if the session id is unknown, 400 if the session is
            already complete.
    """
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    state = sessions[session_id]
    if state["is_complete"]:
        raise HTTPException(status_code=400, detail="Session already complete")

    q_ids = [qid for qid in map(str.strip, question_ids.split(",")) if qid]
    current_questions = state["current_questions"]

    # Build question lookup
    q_map = {q.id: q.text for q in current_questions}

    # Transcribe each audio and create Answer objects
    new_answers = []
    for i, audio_file in enumerate(audio_files):
        audio_bytes = await audio_file.read()
        transcript = await transcription_service.transcribe(audio_bytes)
        qid = _resolve_question_id(i, q_ids, current_questions)
        new_answers.append(
            Answer(
                question_id=qid,
                question_text=q_map.get(qid, f"Question {i+1}"),
                audio_transcript=transcript,
                round_number=state["current_round"],
            )
        )

    # Hand the graph a copy instead of mutating the stored session state: if
    # the graph run raises, the session is left exactly as it was, so a retry
    # does not append the same answers a second time.
    pending_state = {**state, "all_answers": [*state["all_answers"], *new_answers]}

    # Run submit graph
    result = await submit_graph.ainvoke(pending_state)

    if result["is_complete"]:
        # Post-process before storing so a completed session is never saved
        # with unprocessed markdown.
        result["markdown_content"] = post_process_markdown(
            result["markdown_content"], session_id
        )
    sessions[session_id] = result

    # NOTE(review): indexing round_summaries[-1] assumes the graph always
    # appends at least one summary per submitted round — confirm.
    if result["is_complete"]:
        return RoundSubmitResponse(
            round=result["current_round"],
            is_complete=True,
            round_summary=result["round_summaries"][-1],
            checklist_preview=result["markdown_content"],
        )

    return RoundSubmitResponse(
        round=result["current_round"],
        questions=result["current_questions"],
        round_summary=result["round_summaries"][-1],
        is_complete=False,
    )
async def get_results(session_id: str):
    """Return the final checklist and markdown for a completed session.

    Args:
        session_id: Key of an existing entry in ``sessions``.

    Returns:
        ResultsResponse built from the stored state's ``checklist_items`` and
        ``markdown_content``.

    Raises:
        HTTPException: 404 if the session id is unknown, 400 if the session
            has not been marked complete.
    """
    try:
        finished = sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Session not found") from None

    if not finished["is_complete"]:
        raise HTTPException(status_code=400, detail="Session not complete yet")

    checklist_items = finished["checklist_items"]
    markdown_text = finished["markdown_content"]
    return ResultsResponse(
        session_id=session_id,
        checklist=checklist_items,
        markdown=markdown_text,
    )
async def download_checklist(session_id: str):
    """Serve a completed session's markdown checklist as a file attachment.

    Args:
        session_id: Key of an existing entry in ``sessions``.

    Returns:
        Response whose body is the stored markdown encoded as UTF-8, with
        media type ``text/markdown`` and a ``Content-Disposition`` header
        naming the file ``checklist-<session_id>.md``.

    Raises:
        HTTPException: 404 if the session id is unknown, 400 if the session
            has not been marked complete.
    """
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    session_state = sessions[session_id]
    if not session_state["is_complete"]:
        raise HTTPException(status_code=400, detail="Session not complete yet")

    payload = session_state["markdown_content"].encode("utf-8")
    disposition = f"attachment; filename=checklist-{session_id}.md"
    return Response(
        content=payload,
        media_type="text/markdown",
        headers={"Content-Disposition": disposition},
    )