# Source: app/routers/session.py, uploaded by alexorlov with huggingface_hub
# (commit db68bd8, verified).
import uuid
import json
import logging
from typing import Dict, Annotated
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
from fastapi.responses import Response
from app.agent.state import AgentState
from app.agent.graph import start_graph, submit_graph
from app.models.question import QuestionOut, Answer
from app.models.session import SessionStartResponse, RoundSubmitResponse, ResultsResponse
from app.services.transcription import transcription_service
from app.services.file_generator import post_process_markdown
# Module logger (named after this module).
logger = logging.getLogger(name=__name__)

# All endpoints below are mounted under the /api/session prefix.
router = APIRouter(tags=["session"], prefix="/api/session")

# In-memory session storage: session id -> latest agent state. This is a
# plain module-level dict, so it lives only in the current process and is
# lost on restart.
sessions: Dict[str, AgentState] = dict()
def _new_session_id() -> str:
    """Return a short session id that is not already a key of ``sessions``.

    The id is the first 8 characters of a UUID4 string (32 random bits). No
    session is ever removed from ``sessions`` in this module, so collisions
    become realistic over time; regenerate until the id is unused instead of
    silently overwriting someone else's session.
    """
    while True:
        candidate = str(uuid.uuid4())[:8]
        if candidate not in sessions:
            return candidate


@router.post("/start", response_model=SessionStartResponse)
async def start_session():
    """Create a new session and return the questions produced by ``start_graph``.

    A fresh ``AgentState`` is run through ``start_graph``; the resulting state
    is stored in ``sessions`` under a newly generated, currently unused id.

    Returns:
        SessionStartResponse with the new session id, the round number and
        the questions from the graph's result state.
    """
    session_id = _new_session_id()
    initial_state: AgentState = {
        "session_id": session_id,
        "current_round": 0,
        "max_rounds": 3,
        "current_questions": [],
        "all_answers": [],
        "round_summaries": [],
        "checklist_items": [],
        "markdown_content": "",
        "is_complete": False,
    }
    result = await start_graph.ainvoke(initial_state)
    sessions[session_id] = result
    return SessionStartResponse(
        session_id=session_id,
        round=result["current_round"],
        questions=result["current_questions"],
    )
@router.post("/transcribe")
async def transcribe_audio(audio_file: UploadFile = File(...)):
    """Transcribe one uploaded audio file, independent of any session.

    Returns:
        A JSON object ``{"transcript": ...}`` holding whatever
        ``transcription_service.transcribe`` produced for the file's bytes.
    """
    transcript = await transcription_service.transcribe(await audio_file.read())
    return dict(transcript=transcript)
def _resolve_question_id(index, q_ids, current_questions):
    """Pick the question id for the ``index``-th uploaded audio file.

    Priority: the explicitly submitted id at that position, then the id of
    the current round's question at the same position, then ``q{index}``.
    """
    if index < len(q_ids):
        return q_ids[index]
    if index < len(current_questions):
        return current_questions[index].id
    return f"q{index}"


@router.post("/{session_id}/submit", response_model=RoundSubmitResponse)
async def submit_round(
    session_id: str,
    audio_files: list[UploadFile] = File(...),
    question_ids: Annotated[str, Form()] = "",
):
    """Transcribe one round of audio answers and advance the session.

    Each file in ``audio_files`` is transcribed and turned into an ``Answer``.
    ``question_ids`` is an optional comma-separated list that assigns ids to
    the files by position; missing entries fall back to the current round's
    questions by position. The accumulated answers are passed to
    ``submit_graph``. The stored session is only replaced after the graph
    returns, so a failure during transcription or the graph run leaves the
    session untouched and the round can be retried without duplicate answers.

    Raises:
        HTTPException: 404 if the session does not exist, 400 if it is
            already complete.
    """
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")
    state = sessions[session_id]
    if state["is_complete"]:
        raise HTTPException(status_code=400, detail="Session already complete")

    q_ids = [qid.strip() for qid in question_ids.split(",") if qid.strip()]
    current_questions = state["current_questions"]
    # Question id -> question text, for labelling the answers.
    q_map = {q.id: q.text for q in current_questions}

    new_answers = []
    for i, audio_file in enumerate(audio_files):
        audio_bytes = await audio_file.read()
        transcript = await transcription_service.transcribe(audio_bytes)
        qid = _resolve_question_id(i, q_ids, current_questions)
        new_answers.append(
            Answer(
                question_id=qid,
                question_text=q_map.get(qid, f"Question {i+1}"),
                audio_transcript=transcript,
                round_number=state["current_round"],
            )
        )

    # Work on a shallow copy. Writing into the stored state would keep the
    # new answers even if submit_graph raises, and a retry would append them
    # a second time.
    pending = dict(state)
    pending["all_answers"] = list(state["all_answers"]) + new_answers
    result = await submit_graph.ainvoke(pending)
    sessions[session_id] = result

    if result["is_complete"]:
        # ``result`` is the stored object, so this update is visible through
        # ``sessions`` as well.
        result["markdown_content"] = post_process_markdown(
            result["markdown_content"], session_id
        )
        return RoundSubmitResponse(
            round=result["current_round"],
            is_complete=True,
            round_summary=result["round_summaries"][-1],
            checklist_preview=result["markdown_content"],
        )

    return RoundSubmitResponse(
        round=result["current_round"],
        questions=result["current_questions"],
        round_summary=result["round_summaries"][-1],
        is_complete=False,
    )
@router.get("/{session_id}/results", response_model=ResultsResponse)
async def get_results(session_id: str):
    """Return the checklist items and markdown stored for a finished session.

    Raises:
        HTTPException: 404 for an unknown session id, 400 while the session's
            ``is_complete`` flag is not yet set.
    """
    state = sessions.get(session_id)
    if state is None:
        raise HTTPException(status_code=404, detail="Session not found")
    if not state["is_complete"]:
        raise HTTPException(status_code=400, detail="Session not complete yet")
    checklist, markdown = state["checklist_items"], state["markdown_content"]
    return ResultsResponse(
        session_id=session_id, checklist=checklist, markdown=markdown
    )
@router.get("/{session_id}/download")
async def download_checklist(session_id: str):
    """Serve a finished session's markdown as a UTF-8 ``.md`` attachment.

    The attachment is named ``checklist-<session_id>.md``.

    Raises:
        HTTPException: 404 for an unknown session id, 400 while the session's
            ``is_complete`` flag is not yet set.
    """
    state = sessions.get(session_id)
    if state is None:
        raise HTTPException(status_code=404, detail="Session not found")
    if not state["is_complete"]:
        raise HTTPException(status_code=400, detail="Session not complete yet")
    payload = state["markdown_content"].encode("utf-8")
    disposition = f"attachment; filename=checklist-{session_id}.md"
    return Response(
        content=payload,
        media_type="text/markdown",
        headers={"Content-Disposition": disposition},
    )