import asyncio
import json
from typing import Annotated, Any, AsyncIterator

from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field

from src.agents.agent_transcript.flow import script_writer_agent
from src.apis.middlewares.auth_middleware import get_current_user
from src.apis.models.user_models import User
from src.utils.logger import logger

user_dependency = Annotated[User, Depends(get_current_user)]

# Number of characters of transcript/comment text echoed in the
# "*_extracted" preview events; the full length is reported separately.
PREVIEW_CHARS = 500


class GenScriptRequest(BaseModel):
    video_link: str = Field(..., description="Video link")
    target_word_count: int = Field(
        2500, ge=2000, le=12000, description="Target word count"
    )
    language: str = Field(..., description="Language")


router = APIRouter()


def _sse(payload: dict[str, Any]) -> str:
    """Serialize *payload* as one Server-Sent Events ``data:`` frame."""
    # default=str: graph state and LangGraph metadata may contain values that
    # json cannot encode natively; stringify them instead of killing the stream.
    return f"data: {json.dumps(payload, default=str)}\n\n"


def _preview(text: str, limit: int = PREVIEW_CHARS) -> str:
    """Return *text* cut to *limit* characters, suffixed with "..." if cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _state_events(
    state: dict[str, Any], announced: dict[str, Any]
) -> list[dict[str, Any]]:
    """Derive progress events from one full-state snapshot of the graph.

    With ``stream_mode="values"`` the graph re-sends its whole state after
    every step, so each derived event is emitted only when the value it
    reports differs from what was last announced on this stream.

    Args:
        state: The state snapshot from a "values" stream event.
        announced: Per-stream memo of already-announced values; updated in
            place.

    Returns:
        The event payloads to send, in emission order.
    """
    events: list[dict[str, Any]] = []

    def is_new(key: str, signature: Any) -> bool:
        # Remember *signature* and report whether it differs from the last one.
        if key in announced and announced[key] == signature:
            return False
        announced[key] = signature
        return True

    for field, event_type in (
        ("transcript", "transcript_extracted"),
        ("comment", "comment_extracted"),
    ):
        text = state.get(field)
        if text and is_new(field, text):
            events.append(
                {
                    "type": event_type,
                    field: _preview(text),
                    "full_length": len(text),
                }
            )

    if "script_count" in state:
        script_count = state["script_count"]
        target_word_count = state.get("target_word_count", 8000)
        if is_new("script_count", (script_count, target_word_count)):
            events.append(
                {
                    "type": "script_count_calculated",
                    "script_count": script_count,
                    "target_word_count": target_word_count,
                }
            )

    scripts = state.get("script_writer_response")
    if scripts and "current_script_index" in state:
        index = state["current_script_index"]
        # Snapshot (index, length, latest script) rather than keeping the list
        # itself: if a node mutates the list in place, a stored reference
        # would always compare equal. Assumes script items are immutable
        # (presumably str) — TODO confirm.
        if is_new("individual_script", (index, len(scripts), scripts[-1])):
            events.append(
                {
                    "type": "individual_script",
                    "script_index": index,
                    "script_content": scripts[-1],
                    "progress": f"{index}/{state.get('script_count', 10)}",
                    "scripts": scripts,
                }
            )

    return events


async def message_generator(
    input_graph: dict,
    config: RunnableConfig,
) -> AsyncIterator[str]:
    """Run the script-writer graph and stream its progress as SSE frames.

    Emitted event types (the JSON ``type`` field):
      - ``message_chunk``: each AI message token chunk with its node/metadata.
      - ``transcript_extracted`` / ``comment_extracted`` /
        ``script_count_calculated`` / ``individual_script``: derived from state
        snapshots, sent only when the reported value changes.
      - ``state_update``: the full state after every graph step.
      - ``final_result``: the collected scripts once the graph finishes.
      - ``error``: the graph raised; the stream ends after this frame.

    Args:
        input_graph: Initial graph input (video_link, target_word_count,
            language).
        config: Runnable config forwarded to the graph.

    Yields:
        ``data: <json>\\n\\n`` strings.
    """
    last_output_state = None
    announced: dict[str, Any] = {}

    try:
        async for event_type, event_message in script_writer_agent.astream(
            input=input_graph, stream_mode=["messages", "values"], config=config
        ):
            logger.debug(f"Event type: {event_type}")

            if event_type == "messages":
                message, metadata = event_message
                if isinstance(message, AIMessageChunk):
                    chunk_data = {
                        "type": "message_chunk",
                        "content": message.content,
                        "metadata": metadata,
                        "node_step": metadata.get("node"),
                    }
                    # Per-token: debug level to avoid flooding the logs.
                    logger.debug(f"Chunk data: {chunk_data}")
                    yield _sse(chunk_data)

            elif event_type == "values":
                last_output_state = event_message
                for payload in _state_events(event_message, announced):
                    yield _sse(payload)
                yield _sse({"type": "state_update", "state": event_message})
    except Exception as exc:
        # Top-level boundary of the stream: the HTTP status has already been
        # sent, so report the failure in-band instead of dropping the
        # connection silently. Cancellation / client disconnect
        # (CancelledError, GeneratorExit) are BaseException and pass through.
        logger.exception(f"Error in streaming: {exc}")
        yield _sse({"type": "error", "message": str(exc)})
        return

    if last_output_state:
        # `or []` also covers a key present with value None.
        scripts = last_output_state.get("script_writer_response") or []
        yield _sse(
            {
                "type": "final_result",
                "scripts": scripts,
                "total_scripts": len(scripts),
            }
        )


@router.post("/gen-script")
async def gen_script(request: GenScriptRequest, user: user_dependency):
    """
    Generate scripts with streaming response (Server-Sent Events).
    """
    # `user` is unused here; the dependency enforces authentication.
    config = RunnableConfig()
    input_graph = {
        "video_link": request.video_link,
        "target_word_count": request.target_word_count,
        "language": request.language,
    }
    return StreamingResponse(
        message_generator(input_graph, config),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )