Spaces:
Paused
Paused
| import json | |
| import asyncio | |
| from typing import List | |
| from typing_extensions import TypedDict | |
| from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate | |
| from langgraph.graph import StateGraph, END | |
| from src.utils.api_key_manager import with_api_manager | |
| from src.helpers.helper import remove_markdown | |
class GraphState(TypedDict):
    """State dictionary threaded through every node of the JSON-generation graph."""

    # The user's request that the plan and every JSON fragment are built from.
    initial_prompt: str
    # Plan text from planning_node; the writing nodes split it on newlines.
    plan: str
    # One parsed JSON fragment per plan step, written by the writing node.
    write_steps: List[dict]
    # Serialized {"plan": ..., "steps": [...]} document from consolidation_node.
    final_json: str
def planning_node(state: GraphState, *, llm) -> GraphState:
    """Ask the model for a step-by-step plan for the requested JSON document.

    Args:
        state: Graph state; reads ``state['initial_prompt']``.
        llm: Chat model whose ``invoke`` result exposes ``.content``.
            NOTE(review): this argument is keyword-only and is not bound by
            ``StateGraph.add_node`` in this file; presumably it is injected by
            a wrapper such as the imported ``with_api_manager`` -- confirm.

    Returns:
        The same ``state`` with ``state['plan']`` set to the model reply,
        stripped and passed through ``remove_markdown``.
    """
    print("\n---PLANNING---\n")
    initial_prompt = state['initial_prompt']

    # The writing nodes split the plan on newlines and treat each non-blank
    # line as one step, so the prompt must ask for one step per line and must
    # never forbid line breaks.
    plan_prompt_text = "\n".join([
        "You need to create a structured JSON based on the following instructions:",
        initial_prompt,
        "Rules:",
        "1. Outline a multi-step plan that will guide the creation of the final JSON. "
        "Put each step on its own line.",
        "2. You must create the entire plan yourself without asking others to create it for you.",
        "3. The steps should be as follows:",
        "   - Each step should be a high-level task or section of the JSON.",
        "   - Check if breaking down each step into smaller, low-level sub-tasks or sections is required.",
        "   - If yes, ONLY include the sub-steps (one sub-step per line).",
        "4. The plan should be concise and clear, and each step and sub-step should be distinct.",
        "5. The plan should be unformatted plain text. DO NOT use bullet points; "
        "separate steps only with new lines.",
        "6. The number of steps should be as few as possible, but still enough to cover ALL sections.",
        "7. If the user request contains any specific details, include them in the plan.",
        "8. DO NOT create the final content, just the plan/outline.",
        "9. DO NOT include any markdown or formatting in the plan.",
    ])

    # The prompt text is passed as a template *variable*, so braces in the
    # user's request are never interpreted as template placeholders.
    chat_template = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template("{text}"),
    ])
    prompt = chat_template.invoke({"text": plan_prompt_text})
    response = llm.invoke(prompt)

    plan = remove_markdown(response.content.strip())
    state['plan'] = plan
    print(plan)
    return state
def writing_node_sync(state: GraphState, *, llm) -> GraphState:
    """Generate and parse one JSON fragment per plan step, one step at a time.

    Every non-blank line of ``state['plan']`` is one step. For each step the
    model is asked for JSON covering only that step; the reply is stripped,
    passed through ``remove_markdown`` and parsed with ``json.loads``.

    Args:
        state: Graph state; reads ``initial_prompt`` and ``plan``.
        llm: Chat model whose ``invoke`` result exposes ``.content``.
            NOTE(review): keyword-only and not bound by ``StateGraph.add_node``
            in this file; presumably injected by a wrapper such as the
            imported ``with_api_manager`` -- confirm.

    Returns:
        The same ``state`` with ``write_steps`` set to the parsed fragments,
        in plan order.

    Raises:
        ValueError: If a step's reply is not valid JSON (chained to the
            underlying ``json.JSONDecodeError``).
    """
    print("\n---WRITING THE JSON---\n")
    initial_prompt = state['initial_prompt']
    # Drop blank lines before numbering so "part N" and error messages count
    # only real steps.
    steps = [line.strip() for line in state['plan'].splitlines() if line.strip()]

    # Loop-invariant: the step text is passed as a variable, so braces in the
    # request or plan are never parsed as template placeholders.
    chat_template = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template("{text}"),
    ])

    partial_jsons: List[dict] = []
    for number, step in enumerate(steps, start=1):
        step_prompt_text = "\n".join([
            f"You are creating part {number} of the final JSON document.",
            "User request:",
            initial_prompt,
            "Plan step (outline):",
            step,
            "Rules:",
            "1. You need to write the JSON data for this step.",
            "2. The JSON should be structured and valid.",
            "3. If the user request contains any specific details, include them in the JSON.",
            "4. If the user request contains the format of the JSON, follow it. "
            "If not, create a generic JSON as you see fit.",
            "5. Respond ONLY with valid JSON for this step without any markdown or formatting.",
        ])
        prompt = chat_template.invoke({"text": step_prompt_text})
        response = llm.invoke(prompt)
        cleaned_result = remove_markdown(response.content.strip())
        try:
            partial_obj = json.loads(cleaned_result)
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse JSON data for step {number}: {e}") from e
        partial_jsons.append(partial_obj)

    state['write_steps'] = partial_jsons
    return state
async def writing_node_async(state: GraphState, *, llm) -> GraphState:
    """Generate and parse one JSON fragment per plan step, all steps concurrently.

    Every non-blank line of ``state['plan']`` is one step. Each step is sent
    to the model via ``llm.ainvoke``; the reply is stripped, passed through
    ``remove_markdown`` and parsed with ``json.loads``.

    Args:
        state: Graph state; reads ``initial_prompt`` and ``plan``.
        llm: Chat model whose awaited ``ainvoke`` result exposes ``.content``.
            NOTE(review): keyword-only and not bound by ``StateGraph.add_node``
            in this file; presumably injected by a wrapper such as the
            imported ``with_api_manager`` -- confirm.

    Returns:
        The same ``state`` with ``write_steps`` set to the parsed fragments,
        in plan order.

    Raises:
        ValueError: If any step's reply is not valid JSON (chained to the
            underlying ``json.JSONDecodeError``). The remaining in-flight
            steps are cancelled first.
    """
    print("\n---WRITING THE JSON---\n")
    initial_prompt = state['initial_prompt']
    # Drop blank lines before numbering so "part N" and error messages count
    # only real steps.
    steps = [line.strip() for line in state['plan'].splitlines() if line.strip()]

    # Shared by all steps: the step text is passed as a variable, so braces in
    # the request or plan are never parsed as template placeholders.
    chat_template = ChatPromptTemplate.from_messages([
        HumanMessagePromptTemplate.from_template("{text}"),
    ])

    async def get_partial_json(number: int, step: str) -> dict:
        """Request the JSON fragment for one plan step and parse it."""
        step_prompt_text = "\n".join([
            f"You are creating part {number} of the final JSON document.",
            "User request:",
            initial_prompt,
            "Plan step (outline):",
            step,
            "Rules:",
            "1. You need to write the JSON data for this step.",
            "2. The JSON should be structured and valid.",
            "3. If the user request contains any specific details, include them in the JSON.",
            "4. If the user request contains the format of the JSON, follow it. "
            "If not, create a generic JSON as you see fit.",
            "5. Respond ONLY with valid JSON for this step without any markdown or formatting.",
        ])
        prompt = chat_template.invoke({"text": step_prompt_text})
        response = await llm.ainvoke(prompt)
        cleaned_result = remove_markdown(response.content.strip())
        try:
            return json.loads(cleaned_result)
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse JSON data for step {number}: {e}") from e

    tasks = [
        asyncio.create_task(get_partial_json(number, step))
        for number, step in enumerate(steps, start=1)
    ]
    try:
        # gather returns results in task order, so write_steps follows the plan.
        partial_jsons = await asyncio.gather(*tasks)
    except Exception:
        # gather propagates the first failure but leaves the other tasks
        # running; cancel them so no orphaned model requests outlive this node.
        for task in tasks:
            task.cancel()
        raise

    state['write_steps'] = list(partial_jsons)
    return state
def consolidation_node(state: GraphState) -> GraphState:
    """Combine the plan and the per-step fragments into one JSON document.

    Reads ``state['plan']`` and ``state['write_steps']`` and serializes them
    as ``{"plan": ..., "steps": [...]}`` with a 2-space indent, keeping
    non-ASCII characters unescaped.

    Returns:
        The same ``state``, with the serialized text stored in
        ``state['final_json']``.
    """
    print("\n---CONSOLIDATING THE JSON---\n")
    state['final_json'] = json.dumps(
        {"plan": state['plan'], "steps": state['write_steps']},
        ensure_ascii=False,
        indent=2,
    )
    return state
def _build_workflow(writing_node) -> StateGraph:
    """Wire planning -> writing -> consolidation -> END and compile the graph.

    Args:
        writing_node: Callable registered as ``"writing_node"``; either
            ``writing_node_sync`` or ``writing_node_async``.

    Returns:
        The result of ``workflow.compile()``.
        NOTE(review): that is the compiled, runnable graph rather than a
        ``StateGraph`` builder; the annotation is kept for compatibility.
    """
    workflow = StateGraph(GraphState)

    workflow.add_node("planning_node", planning_node)
    workflow.add_node("writing_node", writing_node)
    workflow.add_node("consolidation_node", consolidation_node)

    # Strictly linear pipeline: plan, write the fragments, consolidate, stop.
    workflow.set_entry_point("planning_node")
    workflow.add_edge("planning_node", "writing_node")
    workflow.add_edge("writing_node", "consolidation_node")
    workflow.add_edge("consolidation_node", END)

    return workflow.compile()


def create_workflow_sync() -> StateGraph:
    """Return the compiled pipeline whose writing step runs plan steps sequentially."""
    return _build_workflow(writing_node_sync)


def create_workflow_async() -> StateGraph:
    """Return the compiled pipeline whose writing step runs plan steps concurrently.

    The async writing node must be driven through the graph's async API
    (e.g. ``ainvoke``).
    """
    return _build_workflow(writing_node_async)
if __name__ == "__main__":
    import time

    # Adjacent string literals are joined at compile time; each piece ends
    # with its own separating space so the sentences do not run together.
    test_instruction = (
        "Write a 1500-word piece on the HBO TV show Westworld, covering major characters, "
        "themes of AI and consciousness, and how the story might have continued "
        "had it not been cancelled. "
        "Include specific details, quotes, and references to the show and its creators. "
        "Do not include any spoilers for the climax of the show's final season."
    )

    app = create_workflow_async()

    # Only 'initial_prompt' carries input; the other fields are placeholders
    # that the planning, writing and consolidation nodes overwrite.
    state_input: GraphState = {
        "initial_prompt": test_instruction,
        "plan": "",
        "write_steps": [],
        "final_json": "",
    }

    # perf_counter is monotonic and intended for measuring elapsed time.
    start = time.perf_counter()
    final_state = asyncio.run(app.ainvoke(state_input))
    end = time.perf_counter()

    print("\n===== FINAL JSON OUTPUT =====\n")
    print(final_state['final_json'])
    print("=============================\n")
    print("\n===== PERFORMANCE =====\n")
    print(f"Time taken: {end-start:.2f} seconds")
    print("=======================\n")