Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter | |
| from fastapi.responses import StreamingResponse | |
| from langchain_core.messages import AIMessageChunk | |
| from langchain_core.runnables import RunnableConfig | |
| from src.agents.agent_transcript.flow import script_writer_agent | |
| from src.utils.logger import logger | |
| from pydantic import BaseModel | |
| import json | |
| import asyncio | |
class GenScriptRequest(BaseModel):
    """Request payload for the script-generation endpoints."""

    # Link to the source video whose transcript/comments drive generation.
    video_link: str
    # Desired total word count across all generated scripts.
    # NOTE(review): default is 50000 but the original comment said
    # "Default 2500 words", and the streaming fallback below uses 8000 —
    # confirm which value is actually intended.
    target_word_count: int = 50000
| router = APIRouter() | |
def _truncate(text: str, limit: int = 500) -> str:
    """Cap *text* at *limit* characters, appending "..." when it was cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _sse(payload: dict) -> str:
    """Serialize *payload* as one Server-Sent-Events ``data:`` frame.

    ``default=str`` guards against non-JSON-serializable values (e.g. message
    metadata objects); previously only the state/final frames had this guard,
    so a single odd metadata value could turn every chunk into an error frame.
    """
    return f"data: {json.dumps(payload, default=str)}\n\n"


async def message_generator(
    input_graph: dict,
    config: RunnableConfig,
):
    """Stream the script-writer agent's progress as SSE frames.

    Args:
        input_graph: Initial graph state (``video_link``, ``target_word_count``).
        config: Runnable configuration forwarded to ``astream``.

    Yields:
        ``data: <json>\\n\\n`` frames: per-token ``message_chunk`` events,
        extraction/progress events, ``state_update`` snapshots, and a final
        ``final_result`` summary. Errors are emitted as frames, never raised.
    """
    try:
        last_output_state = None
        try:
            async for event in script_writer_agent.astream(
                input=input_graph, stream_mode=["messages", "values"], config=config
            ):
                try:
                    event_type, event_message = event
                    logger.info(f"Event type: {event_type}")
                    if event_type == "messages":
                        message, metadata = event_message
                        if isinstance(message, AIMessageChunk):
                            # Stream AI message chunks as they are produced.
                            # NOTE(review): LangGraph normally publishes the node
                            # name under "langgraph_node", not "node" — confirm
                            # this key, otherwise node_step is always None.
                            node = metadata.get("node")
                            chunk_data = {
                                "type": "message_chunk",
                                "content": message.content,
                                "metadata": metadata,
                                "node_step": node,
                            }
                            logger.info(f"Chunk data: {chunk_data}")
                            yield _sse(chunk_data)
                    elif event_type == "values":
                        # Full graph-state snapshot; remembered so the final
                        # frame can summarize the run after the stream ends.
                        state_data = {"type": "state_update", "state": event_message}
                        last_output_state = event_message

                        # Targeted extractions from the state, truncated so the
                        # frames stay small; full_length reports the real size.
                        if "transcript" in event_message and event_message["transcript"]:
                            yield _sse({
                                "type": "transcript_extracted",
                                "transcript": _truncate(event_message["transcript"]),
                                "full_length": len(event_message["transcript"]),
                            })
                        if "comment" in event_message and event_message["comment"]:
                            yield _sse({
                                "type": "comment_extracted",
                                "comment": _truncate(event_message["comment"]),
                                "full_length": len(event_message["comment"]),
                            })
                        if "script_count" in event_message:
                            yield _sse({
                                "type": "script_count_calculated",
                                "script_count": event_message["script_count"],
                                "target_word_count": event_message.get("target_word_count", 8000),
                            })
                        # Per-script progress: the latest script plus the list so far.
                        if "script_writer_response" in event_message and "current_script_index" in event_message:
                            current_scripts = event_message["script_writer_response"]
                            current_index = event_message["current_script_index"]
                            script_count = event_message.get("script_count", 10)
                            if current_scripts:
                                yield _sse({
                                    "type": "individual_script",
                                    "script_index": current_index,
                                    "script_content": current_scripts[-1],
                                    "progress": f"{current_index}/{script_count}",
                                    "scripts": current_scripts,
                                })
                        yield _sse(state_data)
                except Exception as e:
                    # Per-event failure: report and keep the stream alive.
                    logger.error(f"Error processing event: {e}")
                    yield _sse({"type": "error", "message": str(e)})
        except Exception as e:
            # astream itself failed; the final-result block below still runs.
            logger.error(f"Error in streaming: {e}")
            yield _sse({"type": "error", "message": str(e)})
        # Send final result
        if last_output_state:
            yield _sse({
                "type": "final_result",
                "scripts": last_output_state.get("script_writer_response", []),
                "total_scripts": len(
                    last_output_state.get("script_writer_response", [])
                ),
            })
    except Exception as e:
        # Last-resort guard so the generator never raises into Starlette.
        logger.error(f"Fatal error in message_generator: {e}")
        yield _sse({"type": "fatal_error", "message": str(e)})
async def gen_script(request: GenScriptRequest):
    """
    Generate scripts with streaming response.

    Runs the script-writer agent and streams its progress to the client as
    Server-Sent Events (see ``message_generator`` for the frame types).
    """
    config = RunnableConfig()
    input_graph = {
        "video_link": request.video_link,
        "target_word_count": request.target_word_count,
    }
    return StreamingResponse(
        message_generator(input_graph, config),
        # Declare SSE directly rather than "text/plain" overridden by a manual
        # Content-Type header; the effective response type is unchanged.
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",  # SSE must not be cached/buffered
            "Connection": "keep-alive",
        },
    )
def gen_script_sync(request: GenScriptRequest):
    """
    Generate scripts with synchronous response (non-streaming).

    Invokes the script-writer agent once and returns the collected scripts
    together with the full final state.
    """
    graph_input = {
        "video_link": request.video_link,
        "target_word_count": request.target_word_count,
    }
    result = script_writer_agent.invoke(graph_input)
    scripts = result.get("script_writer_response", [])
    return {
        "scripts": scripts,
        "total_scripts": len(scripts),
        "full_response": result,
    }