Spaces:
Sleeping
Sleeping
| from langgraph.graph import StateGraph, END | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from langchain_openai import ChatOpenAI # or your preferred LLM | |
| from pydantic import BaseModel, Field, field_validator | |
| from typing import TypedDict, List | |
| import os | |
| from pydantic import field_validator # needed for custom validation | |
| from dotenv import load_dotenv | |
| load_dotenv() # take environment variables from .env. | |
# Define the structured output model
class StoryOutput(BaseModel):
    """Structured output for the storyteller agent.

    This model is passed to ``llm.with_structured_output`` in
    ``StorytellerAgent``, which turns it into the schema the LLM fills in.
    The field descriptions therefore double as instructions to the model and
    must agree with the validation constraints declared on each field.
    """

    polished_story: str = Field(
        description="A refined version of the story with improved flow, grammar, and engagement"
    )
    # The description used to ask for 5-10 keywords while validation only
    # accepted 3-7, so the model could be instructed to produce a list that
    # would then be rejected. The description now matches the constraints.
    keywords: List[str] = Field(
        description="A list of 3-7 key terms that represent the main themes, characters, or concepts",
        # Pydantic v2 names for the deprecated ``min_items`` / ``max_items``.
        min_length=3,
        max_length=7,
    )
# Define the state structure.
# Graph state carried through the storyteller workflow:
#   original_story -- raw story text supplied by the caller
#   polished_story -- rewritten story produced by the storyteller node
#   keywords       -- keywords extracted together with the polished story
#   messages       -- log entries; the storyteller node adds dicts with
#                     "role" and "content" keys
AgentState = TypedDict(
    "AgentState",
    {
        "original_story": str,
        "polished_story": str,
        "keywords": List[str],
        "messages": List[dict],
    },
)
# Storyteller Agent with Structured Output
class StorytellerAgent:
    """LangGraph node that polishes a story and extracts keywords in one LLM call.

    The wrapped LLM is bound to the ``StoryOutput`` schema, so each call
    yields a ``StoryOutput`` instance rather than free-form text.
    """

    def __init__(self, llm):
        """Bind *llm* to the ``StoryOutput`` schema and set the system prompt.

        Args:
            llm: A LangChain chat model exposing ``with_structured_output``
                (``create_storyteller_graph`` passes a ``ChatOpenAI``).
        """
        self.structured_llm = llm.with_structured_output(StoryOutput)
        self.system_prompt = """You are a skilled storyteller AI. Your job is to take raw, confessional-style stories and transform them into emotionally engaging, narrative-driven pieces. The rewritten story should:
1. Preserve the original events and meaning but present them in a captivating way.
2. Use character names (instead of “my brother,” “my sister”) to make the story feel alive.
3. Add dialogue, atmosphere, and inner thoughts to create tension and immersion.
4. Write in a third-person narrative style, as if the story is being shared by an observer.
5. Maintain a natural, human voice — conversational, reflective, and vivid.
6. Balance realism with storytelling techniques (scene-setting, emotional beats, sensory details).
7. Keep the length roughly 2–3x the original input, ensuring it feels like a polished story.
Your goal is to make the reader feel emotionally invested, as though they’re listening to someone recounting a deeply personal and dramatic life event.
"""

    def __call__(self, state: AgentState) -> AgentState:
        """Polish ``state['original_story']`` and extract its keywords.

        Args:
            state: Graph state. ``original_story`` must be present;
                ``messages`` is optional and treated as empty when missing.

        Returns:
            A new state dict that copies *state*, fills in ``polished_story``
            and ``keywords``, and has one status entry appended to a copy of
            ``messages``. Neither *state* nor its ``messages`` list is mutated.

        Raises:
            KeyError: if ``original_story`` is missing from *state*.
            ValueError: if the structured LLM call returns no result.
        """
        messages = [
            SystemMessage(content=self.system_prompt),
            HumanMessage(content=f"Please polish this story and extract keywords:\n\n{state['original_story']}"),
        ]
        response: StoryOutput = self.structured_llm.invoke(messages)
        # Depending on the structured-output method, the parser may yield None
        # when the model produces nothing parseable; fail with a clear message
        # instead of an AttributeError further down.
        if response is None:
            raise ValueError("Structured LLM call returned no StoryOutput")

        log_entry = {
            "role": "assistant",
            "content": f"Polished story and extracted {len(response.keywords)} keywords",
        }
        # Build a fresh state instead of mutating the caller's dict and
        # appending to its ``messages`` list in place (the caller may still
        # hold that list, e.g. the initial state passed to ``graph.invoke``).
        return {
            **state,
            "polished_story": response.polished_story,
            "keywords": response.keywords,
            "messages": [*state.get("messages", []), log_entry],
        }
# Create the graph functions
def create_storyteller_graph(enhanced=False, *, llm=None):
    """Build and compile a single-node LangGraph workflow around StorytellerAgent.

    Args:
        enhanced: Accepted for backward compatibility but currently unused;
            the same ``StorytellerAgent`` is built whatever its value.
        llm: Optional chat model to use instead of the default. It must
            support ``with_structured_output``. When omitted, a ``gpt-4o``
            ``ChatOpenAI`` client is created with ``temperature=0.2``,
            ``max_tokens=10000`` and the API key read from the
            ``OPENAI_API_KEY`` environment variable.

    Returns:
        The compiled graph, in which ``storyteller`` is the entry point and
        the only node, followed by ``END``.
    """
    if llm is None:
        llm = ChatOpenAI(
            model="gpt-4o",
            api_key=os.getenv("OPENAI_API_KEY"),
            temperature=0.2,
            max_tokens=10000,
        )
    # There is only one agent type; ``enhanced`` does not select another yet.
    storyteller = StorytellerAgent(llm)

    workflow = StateGraph(AgentState)
    workflow.add_node("storyteller", storyteller)
    workflow.set_entry_point("storyteller")
    workflow.add_edge("storyteller", END)
    return workflow.compile()
# Usage functions
def process_story(original_story: str, enhanced=False):
    """Run the storyteller graph once on *original_story*.

    A new graph is built for every call (``enhanced`` is forwarded to
    ``create_storyteller_graph``). The graph starts from a state whose
    ``polished_story`` is empty and whose ``keywords`` and ``messages`` are
    empty lists.

    Returns:
        A dict with only the ``polished_story`` and ``keywords`` entries of
        the final graph state.
    """
    compiled = create_storyteller_graph(enhanced)
    final_state = compiled.invoke(
        dict(
            original_story=original_story,
            polished_story="",
            keywords=[],
            messages=[],
        )
    )
    wanted = ("polished_story", "keywords")
    return {field: final_state[field] for field in wanted}
# Example with validation
class ValidatedStoryOutput(BaseModel):
    """Story output with additional validation.

    ``polished_story`` must be at least 50 characters long and contain at
    least 10 whitespace-separated words. ``keywords`` must hold 3 to 7 entries.
    """

    polished_story: str = Field(
        description="Enhanced story",
        min_length=50,  # Ensure minimum story length (in characters)
    )
    keywords: List[str] = Field(
        description="Story keywords",
        # Pydantic v2 names for the deprecated ``min_items`` / ``max_items``.
        min_length=3,
        max_length=7,
    )

    # Without the ``field_validator`` decorator this method was never
    # registered with Pydantic, so the word-count check silently never ran.
    @field_validator("polished_story")
    @classmethod
    def validate_story_quality(cls, v: str) -> str:
        """Reject a polished story that has fewer than 10 words.

        Raises:
            ValueError: if *v* has fewer than 10 whitespace-separated words.
                Pydantic reports it as a ``ValidationError`` when the model
                is constructed.
        """
        if len(v.split()) < 10:
            raise ValueError("Polished story must contain at least 10 words")
        return v