# AutoGenScript — src/apis/routers/gen_script.py
# Streaming script-generation endpoint (SSE over FastAPI StreamingResponse).
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import RunnableConfig
from src.agents.agent_transcript.flow import script_writer_agent
from src.utils.logger import logger
from pydantic import BaseModel, Field
import json
import asyncio
from src.apis.middlewares.auth_middleware import get_current_user
from typing import Annotated
from src.apis.models.user_models import User
# FastAPI dependency alias: resolves the authenticated user for each request.
user_dependency = Annotated[User, Depends(get_current_user)]
class GenScriptRequest(BaseModel):
    """Request body for the /gen-script endpoint."""

    # URL of the source video whose transcript drives script generation.
    video_link: str = Field(..., description="Video link")
    # Desired total word count of the generated output; bounded 2000-12000.
    target_word_count: int = Field(
        default=2500, ge=2000, le=12000, description="Target word count"
    )
    # Output language for the generated script.
    language: str = Field(..., description="Language")
# Router mounted by the application; hosts the /gen-script endpoint below.
router = APIRouter()
async def message_generator(
    input_graph: dict,
    config: RunnableConfig,
):
    """Stream agent events to the client as Server-Sent Events (SSE) frames.

    Runs ``script_writer_agent.astream`` over ``input_graph`` and yields
    ``data: <json>\\n\\n`` frames for:

    - AI message chunks (``"message_chunk"``),
    - transcript / comment previews truncated to 500 characters,
    - script-count and per-script progress updates,
    - every full state snapshot (``"state_update"``),
    - a terminal ``"final_result"`` frame containing all generated scripts.

    Any exception is reported to the client as an ``"error"`` frame instead
    of silently dropping the connection.
    """
    last_output_state = None
    try:
        async for event in script_writer_agent.astream(
            input=input_graph, stream_mode=["messages", "values"], config=config
        ):
            event_type, event_message = event
            logger.info(f"Event type: {event_type}")
            if event_type == "messages":
                message, metadata = event_message
                if isinstance(message, AIMessageChunk):
                    # NOTE(review): LangGraph publishes the node name under
                    # "langgraph_node"; the original "node" lookup is kept
                    # first for compatibility, with a fallback added.
                    node = metadata.get("node") or metadata.get("langgraph_node")
                    chunk_data = {
                        "type": "message_chunk",
                        "content": message.content,
                        "metadata": metadata,
                        "node_step": node,
                    }
                    logger.info(f"Chunk data: {chunk_data}")
                    # default=str: metadata may contain non-JSON-serializable
                    # objects (the same guard already used for state_data).
                    yield f"data: {json.dumps(chunk_data, default=str)}\n\n"
            elif event_type == "values":
                # Full state snapshot; emitted last, after derived frames.
                state_data = {"type": "state_update", "state": event_message}
                last_output_state = event_message
                # Transcript preview, truncated to 500 chars for the wire.
                if event_message.get("transcript"):
                    transcript = event_message["transcript"]
                    transcript_data = {
                        "type": "transcript_extracted",
                        "transcript": (
                            transcript[:500] + "..."
                            if len(transcript) > 500
                            else transcript
                        ),
                        "full_length": len(transcript),
                    }
                    yield f"data: {json.dumps(transcript_data)}\n\n"
                # Comment preview, same truncation policy.
                if event_message.get("comment"):
                    comment = event_message["comment"]
                    comment_data = {
                        "type": "comment_extracted",
                        "comment": (
                            comment[:500] + "..." if len(comment) > 500 else comment
                        ),
                        "full_length": len(comment),
                    }
                    yield f"data: {json.dumps(comment_data)}\n\n"
                if "script_count" in event_message:
                    script_count_data = {
                        "type": "script_count_calculated",
                        "script_count": event_message["script_count"],
                        "target_word_count": event_message.get(
                            "target_word_count", 8000
                        ),
                    }
                    yield f"data: {json.dumps(script_count_data)}\n\n"
                # Per-script progress: emit the most recently produced script.
                if (
                    "script_writer_response" in event_message
                    and "current_script_index" in event_message
                ):
                    current_scripts = event_message["script_writer_response"]
                    current_index = event_message["current_script_index"]
                    script_count = event_message.get("script_count", 10)
                    if current_scripts:
                        individual_script_data = {
                            "type": "individual_script",
                            "script_index": current_index,
                            "script_content": current_scripts[-1],
                            "progress": f"{current_index}/{script_count}",
                            "scripts": current_scripts,
                        }
                        yield f"data: {json.dumps(individual_script_data)}\n\n"
                yield f"data: {json.dumps(state_data, default=str)}\n\n"
    except Exception as e:  # top-level stream boundary: report, don't drop
        # Surface the failure to the SSE client instead of terminating the
        # connection with no explanation (restores the intent of the
        # previously commented-out handlers).
        logger.exception(f"Error in streaming: {e}")
        yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
        return
    # Final frame: every script collected from the last observed state.
    if last_output_state:
        final_data = {
            "type": "final_result",
            "scripts": last_output_state.get("script_writer_response", []),
            "total_scripts": len(last_output_state.get("script_writer_response", [])),
        }
        yield f"data: {json.dumps(final_data, default=str)}\n\n"
@router.post("/gen-script")
async def gen_script(request: GenScriptRequest, user: user_dependency):
    """
    Generate scripts with a streaming (SSE) response.

    Requires an authenticated user via ``user_dependency``. The request body
    supplies the video link, target word count, and output language, which
    are forwarded to the script-writer agent graph.
    """
    config = RunnableConfig()
    input_graph = {
        "video_link": request.video_link,
        "target_word_count": request.target_word_count,
        "language": request.language,
    }
    return StreamingResponse(
        message_generator(input_graph, config),
        # Fix: was "text/plain", contradicting the explicit
        # "Content-Type: text/event-stream" header; let media_type set the
        # SSE content type directly.
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )