import asyncio
import json
from typing import Any, AsyncIterator, Iterator

from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel

from src.agents.agent_transcript.flow import script_writer_agent
from src.utils.logger import logger

# Longest transcript/comment excerpt (in characters) sent in a preview event.
_PREVIEW_LIMIT = 500
# Values reported when the agent state does not carry the corresponding key.
# NOTE(review): the target fallback differs from GenScriptRequest's default
# (50000); presumably the state always echoes the request value -- confirm.
_FALLBACK_TARGET_WORD_COUNT = 8000
_FALLBACK_SCRIPT_COUNT = 10


class GenScriptRequest(BaseModel):
    """Request body accepted by both script-generation endpoints."""

    video_link: str
    target_word_count: int = 50000  # Defaults to 50,000 words


router = APIRouter()


def _sse(payload: dict) -> str:
    """Serialize *payload* as one ``data: <json>`` Server-Sent-Events frame.

    ``default=str`` is applied to every frame so that a value that is not
    JSON-native is stringified instead of aborting the event with a
    ``TypeError``.
    """
    return f"data: {json.dumps(payload, default=str)}\n\n"


def _preview(text: Any, limit: int = _PREVIEW_LIMIT) -> Any:
    """Return *text* cut to *limit* items with a trailing ``"..."`` if longer."""
    return text[:limit] + "..." if len(text) > limit else text


def _chunk_payload(message: AIMessageChunk, metadata: dict) -> dict:
    """Build the ``message_chunk`` event for one streamed AI message chunk."""
    # An explicit "node" key still wins; otherwise fall back to
    # "langgraph_node", the key LangGraph documents for the emitting node in
    # "messages" stream mode.
    node = metadata.get("node") or metadata.get("langgraph_node")
    return {
        "type": "message_chunk",
        "content": message.content,
        "metadata": metadata,
        "node_step": node,
    }


def _state_payloads(state: dict) -> Iterator[dict]:
    """Yield the events derived from one ``values`` state snapshot.

    Derived events (transcript, comment, script count, individual script) come
    first, in that order, each only when its keys are present; the full
    ``state_update`` event is always yielded last.
    """
    transcript = state.get("transcript")
    if transcript:
        yield {
            "type": "transcript_extracted",
            "transcript": _preview(transcript),
            "full_length": len(transcript),
        }

    comment = state.get("comment")
    if comment:
        yield {
            "type": "comment_extracted",
            "comment": _preview(comment),
            "full_length": len(comment),
        }

    if "script_count" in state:
        yield {
            "type": "script_count_calculated",
            "script_count": state["script_count"],
            "target_word_count": state.get(
                "target_word_count", _FALLBACK_TARGET_WORD_COUNT
            ),
        }

    if "script_writer_response" in state and "current_script_index" in state:
        current_scripts = state["script_writer_response"]
        current_index = state["current_script_index"]
        script_count = state.get("script_count", _FALLBACK_SCRIPT_COUNT)
        if current_scripts:
            yield {
                "type": "individual_script",
                "script_index": current_index,
                "script_content": current_scripts[-1],
                "progress": f"{current_index}/{script_count}",
                "scripts": current_scripts,
            }

    yield {"type": "state_update", "state": state}


async def message_generator(
    input_graph: dict,
    config: RunnableConfig,
) -> AsyncIterator[str]:
    """Stream the script-writer agent's progress as SSE ``data:`` frames.

    Runs ``script_writer_agent.astream`` with the ``messages`` and ``values``
    stream modes and converts every event into one or more JSON frames.

    Args:
        input_graph: Input passed to the agent as ``input``.
        config: Runnable configuration forwarded to the agent.

    Yields:
        SSE frames whose JSON ``type`` is one of ``message_chunk``,
        ``transcript_extracted``, ``comment_extracted``,
        ``script_count_calculated``, ``individual_script``, ``state_update``,
        ``error``, ``final_result`` or ``fatal_error``.

    Errors are never raised to the caller: a failure while handling a single
    event or while iterating the stream is reported as an ``error`` frame, and
    anything else as a ``fatal_error`` frame.
    """
    try:
        last_output_state = None
        try:
            async for event in script_writer_agent.astream(
                input=input_graph,
                stream_mode=["messages", "values"],
                config=config,
            ):
                # Per-event guard: one bad event must not end the stream.
                try:
                    event_type, event_message = event
                    logger.info(f"Event type: {event_type}")

                    if event_type == "messages":
                        message, metadata = event_message
                        if isinstance(message, AIMessageChunk):
                            chunk_data = _chunk_payload(message, metadata)
                            logger.info(f"Chunk data: {chunk_data}")
                            yield _sse(chunk_data)
                    elif event_type == "values":
                        # Remember the latest snapshot for the final summary.
                        last_output_state = event_message
                        for payload in _state_payloads(event_message):
                            yield _sse(payload)
                except Exception as e:
                    logger.error(f"Error processing event: {e}")
                    yield _sse({"type": "error", "message": str(e)})
        except Exception as e:
            logger.error(f"Error in streaming: {e}")
            yield _sse({"type": "error", "message": str(e)})

        # Send final result (also after a streaming error, if any snapshot
        # was received before it).
        if last_output_state:
            scripts = last_output_state.get("script_writer_response", [])
            yield _sse(
                {
                    "type": "final_result",
                    "scripts": scripts,
                    "total_scripts": len(scripts),
                }
            )
    except Exception as e:
        logger.error(f"Fatal error in message_generator: {e}")
        yield _sse({"type": "fatal_error", "message": str(e)})


@router.post("/gen-script")
async def gen_script(request: GenScriptRequest):
    """
    Generate scripts with streaming response
    """
    config = RunnableConfig()
    input_graph = {
        "video_link": request.video_link,
        "target_word_count": request.target_word_count,
    }

    return StreamingResponse(
        message_generator(input_graph, config),
        # Agrees with the explicit Content-Type header below; the body is an
        # SSE stream, not plain text.
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Content-Type": "text/event-stream",
        },
    )


@router.post("/gen-script-sync")
def gen_script_sync(request: GenScriptRequest):
    """
    Generate scripts with synchronous response (non-streaming)
    """
    response = script_writer_agent.invoke(
        {
            "video_link": request.video_link,
            "target_word_count": request.target_word_count,
        }
    )
    scripts = response.get("script_writer_response", [])

    return {
        "scripts": scripts,
        "total_scripts": len(scripts),
        "full_response": response,
    }