Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, Depends | |
| from fastapi.responses import StreamingResponse | |
| from langchain_core.messages import AIMessageChunk | |
| from langchain_core.runnables import RunnableConfig | |
| from src.agents.agent_transcript.flow import script_writer_agent | |
| from src.utils.logger import logger | |
| from pydantic import BaseModel, Field | |
| import json | |
| import asyncio | |
| from src.apis.middlewares.auth_middleware import get_current_user | |
| from typing import Annotated | |
| from src.apis.models.user_models import User | |
# Reusable parameter annotation: FastAPI resolves a parameter typed with this
# alias by calling get_current_user and injecting the resulting User
# (presumably rejecting unauthenticated requests — see get_current_user).
user_dependency = Annotated[User, Depends(dependency=get_current_user)]
class GenScriptRequest(BaseModel):
    # Request body for the script-generation endpoint. Every field is copied
    # into the agent's input state by gen_script.

    # Required.
    video_link: str = Field(default=..., description="Video link")

    # Optional; defaults to 2500 and must lie within [2000, 12000].
    target_word_count: int = Field(
        default=2500,
        ge=2000,
        le=12000,
        description="Target word count",
    )

    # Required.
    language: str = Field(default=..., description="Language")
# Module-level router; presumably included into the FastAPI app elsewhere.
router: APIRouter = APIRouter()
def _sse(payload: dict) -> str:
    """Format *payload* as one Server-Sent Events ``data:`` frame.

    ``default=str`` stringifies values the json module cannot encode natively
    (LangGraph metadata and graph state are not guaranteed to be plain JSON),
    so a single odd value does not abort the whole stream with a TypeError.
    """
    return f"data: {json.dumps(payload, default=str)}\n\n"


def _preview(text, limit: int = 500):
    """Return the first *limit* characters of *text*, plus "..." if it was longer."""
    return text[:limit] + "..." if len(text) > limit else text


def _state_frames(state: dict):
    """Yield the SSE frames derived from one ``values`` snapshot of the graph state.

    Emitted in this order, each only when its keys are present/non-empty:
    ``transcript_extracted``, ``comment_extracted``, ``script_count_calculated``,
    ``individual_script``; ``state_update`` (the full state) is always last.
    """
    transcript = state.get("transcript")
    if transcript:
        yield _sse(
            {
                "type": "transcript_extracted",
                "transcript": _preview(transcript),
                "full_length": len(transcript),
            }
        )

    comment = state.get("comment")
    if comment:
        yield _sse(
            {
                "type": "comment_extracted",
                "comment": _preview(comment),
                "full_length": len(comment),
            }
        )

    if "script_count" in state:
        yield _sse(
            {
                "type": "script_count_calculated",
                "script_count": state["script_count"],
                # NOTE(review): this 8000 fallback differs from the request
                # model's 2500 default; it only applies if the state lacks the key.
                "target_word_count": state.get("target_word_count", 8000),
            }
        )

    if "script_writer_response" in state and "current_script_index" in state:
        scripts = state["script_writer_response"]
        if scripts:
            index = state["current_script_index"]
            yield _sse(
                {
                    "type": "individual_script",
                    "script_index": index,
                    "script_content": scripts[-1],
                    "progress": f"{index}/{state.get('script_count', 10)}",
                    "scripts": scripts,
                }
            )

    yield _sse({"type": "state_update", "state": state})


async def message_generator(
    input_graph: dict,
    config: RunnableConfig,
):
    """Stream ``script_writer_agent`` progress as Server-Sent Events frames.

    The graph runs with both the ``"messages"`` and ``"values"`` stream modes.
    Emitted event ``type`` values:

    - ``message_chunk``: each ``AIMessageChunk`` with its metadata and node name.
    - ``transcript_extracted`` / ``comment_extracted``: a 500-character
      preview of the state's ``transcript`` / ``comment`` plus its full length.
    - ``script_count_calculated``: whenever the state contains ``script_count``.
    - ``individual_script``: the latest script and progress once scripts exist.
    - ``state_update``: the full state after every ``values`` event.
    - ``final_result``: all scripts from the last state, after the graph ends.
    - ``error``: if the graph raises; the traceback is logged and the stream
      ends without a ``final_result``.

    Args:
        input_graph: Initial graph state passed to ``astream``.
        config: Runnable config passed to ``astream``.

    Yields:
        SSE frame strings (``data: <json>`` followed by a blank line).
    """
    last_output_state = None
    try:
        async for event_type, event_message in script_writer_agent.astream(
            input=input_graph, stream_mode=["messages", "values"], config=config
        ):
            # One event per streamed token: keep these out of info-level logs.
            logger.debug(f"Event type: {event_type}")
            if event_type == "messages":
                message, metadata = event_message
                if isinstance(message, AIMessageChunk):
                    chunk_data = {
                        "type": "message_chunk",
                        "content": message.content,
                        "metadata": metadata,
                        # LangGraph reports the emitting node as "langgraph_node";
                        # a plain "node" key is still honoured if present.
                        "node_step": metadata.get("node")
                        or metadata.get("langgraph_node"),
                    }
                    logger.debug(f"Chunk data: {chunk_data}")
                    yield _sse(chunk_data)
            elif event_type == "values":
                last_output_state = event_message
                for frame in _state_frames(event_message):
                    yield frame
    except Exception as e:
        # Under a StreamingResponse the status line has already been sent, so
        # the failure can only be reported in-band before ending the stream.
        # CancelledError (client disconnect) is a BaseException and passes through.
        logger.exception(f"Error in streaming: {e}")
        yield _sse({"type": "error", "message": str(e)})
        return

    if last_output_state:
        scripts = last_output_state.get("script_writer_response", [])
        yield _sse(
            {
                "type": "final_result",
                "scripts": scripts,
                "total_scripts": len(scripts),
            }
        )
async def gen_script(request: GenScriptRequest, user: user_dependency):
    """
    Generate scripts with streaming response.

    Copies the request fields into the agent's input state and returns a
    Server-Sent Events stream produced by ``message_generator``.
    """
    # NOTE(review): no route decorator is visible on this function; it is
    # presumably registered on ``router`` elsewhere — confirm, or the endpoint
    # is unreachable.
    # ``user`` is not read here; when mounted as an endpoint, declaring it makes
    # FastAPI run the auth dependency before the stream starts.
    config = RunnableConfig()
    input_graph = {
        "video_link": request.video_link,
        "target_word_count": request.target_word_count,
        "language": request.language,
    }
    return StreamingResponse(
        message_generator(input_graph, config),
        # Declare the SSE content type once, via media_type, instead of passing
        # "text/plain" and silently overriding it with a Content-Type header.
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )