# AutoGenScript: src/apis/routers/gen_script.py
# Uploaded by ABAO77 in commit 172064c ("Upload 60 files"); original size 6.86 kB.
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables import RunnableConfig
from src.agents.agent_transcript.flow import script_writer_agent
from src.utils.logger import logger
from pydantic import BaseModel
import json
import asyncio
class GenScriptRequest(BaseModel):
    """Request body shared by the ``/gen-script`` and ``/gen-script-sync`` endpoints.

    Attributes:
        video_link: Video link, forwarded unchanged to the agent's input state
            under the ``video_link`` key.
        target_word_count: Target word count, forwarded unchanged to the
            agent's input state under the ``target_word_count`` key.
    """

    video_link: str
    target_word_count: int = 50000  # Default: 50,000 words
# Router that the script-generation endpoints in this module register on.
router = APIRouter()
def _sse(payload: dict) -> str:
    """Serialize *payload* as one Server-Sent Events ``data:`` frame.

    ``default=str`` is applied to every frame so that values ``json`` cannot
    encode natively (objects inside chunk metadata or graph state) are
    stringified instead of aborting the frame.
    """
    return f"data: {json.dumps(payload, default=str)}\n\n"


def _preview(text, limit: int = 500):
    """Return *text* cut to *limit* characters with "..." appended when longer."""
    return text[:limit] + "..." if len(text) > limit else text


def _state_frames(state: dict):
    """Yield the SSE frames derived from one ``values`` event (the graph state).

    Extraction and progress frames come first. The ``state_update`` frame
    carrying the whole state is always yielded last.
    """
    # Preview frames for the large text fields, only when they are non-empty.
    for key, frame_type in (
        ("transcript", "transcript_extracted"),
        ("comment", "comment_extracted"),
    ):
        text = state.get(key)
        if text:
            yield _sse({"type": frame_type, key: _preview(text), "full_length": len(text)})

    if "script_count" in state:
        yield _sse(
            {
                "type": "script_count_calculated",
                "script_count": state["script_count"],
                # NOTE(review): fallback differs from GenScriptRequest's default
                # (50000); the key is normally present since the route sets it.
                "target_word_count": state.get("target_word_count", 8000),
            }
        )

    # Per-script progress: the most recent script plus the full list so far.
    if "script_writer_response" in state and "current_script_index" in state:
        scripts = state["script_writer_response"]
        if scripts:
            index = state["current_script_index"]
            yield _sse(
                {
                    "type": "individual_script",
                    "script_index": index,
                    "script_content": scripts[-1],
                    "progress": f"{index}/{state.get('script_count', 10)}",
                    "scripts": scripts,
                }
            )

    yield _sse({"type": "state_update", "state": state})


async def message_generator(
    input_graph: dict,
    config: RunnableConfig,
):
    """Stream ``script_writer_agent`` progress as Server-Sent Events frames.

    Runs the agent with ``stream_mode=["messages", "values"]`` and yields
    strings of the form ``data: <json>`` followed by a blank line. Frame
    ``type`` values:

    * ``message_chunk``: every ``AIMessageChunk`` with its metadata and node name.
    * ``transcript_extracted`` / ``comment_extracted``: a 500-character preview
      plus the full length, whenever the state holds a non-empty value.
    * ``script_count_calculated``: whenever the state holds ``script_count``.
    * ``individual_script``: the latest script and ``index/count`` progress.
    * ``state_update``: the state carried by each ``values`` event.
    * ``error``: one event could not be processed (streaming continues), or
      the agent stream itself raised (the final frame may still follow).
    * ``final_result``: the scripts from the last observed state, if any.
    * ``fatal_error``: an unexpected failure outside the handling above.

    The extraction and progress frames are re-emitted on every ``values``
    event whose state still satisfies their condition, so clients may receive
    them more than once.

    Args:
        input_graph: Initial agent state (``video_link``, ``target_word_count``).
        config: Runnable config forwarded to ``astream``.

    Yields:
        SSE-formatted frame strings.
    """
    try:
        last_output_state = None
        try:
            async for event in script_writer_agent.astream(
                input=input_graph, stream_mode=["messages", "values"], config=config
            ):
                try:
                    event_type, event_message = event
                    logger.info(f"Event type: {event_type}")
                    if event_type == "messages":
                        message, metadata = event_message
                        if isinstance(message, AIMessageChunk):
                            chunk_data = {
                                "type": "message_chunk",
                                "content": message.content,
                                "metadata": metadata,
                                "node_step": metadata.get("node"),
                            }
                            logger.info(f"Chunk data: {chunk_data}")
                            yield _sse(chunk_data)
                    elif event_type == "values":
                        last_output_state = event_message
                        for frame in _state_frames(event_message):
                            yield frame
                except Exception as e:
                    # Report and keep streaming; a single bad event must not end the stream.
                    logger.error(f"Error processing event: {e}")
                    yield _sse({"type": "error", "message": str(e)})
        except Exception as e:
            logger.error(f"Error in streaming: {e}")
            yield _sse({"type": "error", "message": str(e)})

        # Send the final result from whatever state was observed last.
        if last_output_state:
            # `or []` also covers an explicit None, which would break len().
            scripts = last_output_state.get("script_writer_response") or []
            yield _sse(
                {
                    "type": "final_result",
                    "scripts": scripts,
                    "total_scripts": len(scripts),
                }
            )
    except Exception as e:
        logger.error(f"Fatal error in message_generator: {e}")
        yield _sse({"type": "fatal_error", "message": str(e)})
@router.post("/gen-script")
async def gen_script(request: GenScriptRequest):
    """Generate scripts for a video, streaming progress as Server-Sent Events.

    The frames are produced by ``message_generator``. Each one is a
    ``data: <json>`` line followed by a blank line, suitable for an SSE client
    such as ``EventSource``.
    """
    config = RunnableConfig()
    input_graph = {
        "video_link": request.video_link,
        "target_word_count": request.target_word_count,
    }
    # Declare the SSE media type directly. Previously "text/plain" was passed
    # and then overridden by a hand-written Content-Type header.
    return StreamingResponse(
        message_generator(input_graph, config),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
@router.post("/gen-script-sync")
def gen_script_sync(request: GenScriptRequest):
    """Run the script-writer agent to completion and return its output at once.

    This is the non-streaming counterpart of ``/gen-script``. It calls the
    agent's synchronous ``invoke`` and returns three things: the generated
    scripts, how many there are, and the raw final agent state.
    """
    initial_state = dict(
        video_link=request.video_link,
        target_word_count=request.target_word_count,
    )
    final_state = script_writer_agent.invoke(initial_state)
    scripts = final_state.get("script_writer_response", [])
    return dict(
        scripts=scripts,
        total_scripts=len(scripts),
        full_response=final_state,
    )