Spaces:
Build error
Build error
kamaleswar Mohanta commited on
Commit ·
c1fa745
1
Parent(s): 90dcfc0
changed blog generation to HIN orchestrator worker.
Browse filesneed to improve with use of tools for updated data on topics
- src/langgraphagenticai/LLMS/__pycache__/chatgptllm.cpython-312.pyc +0 -0
- src/langgraphagenticai/LLMS/__pycache__/geminillm.cpython-312.pyc +0 -0
- src/langgraphagenticai/LLMS/__pycache__/groqllm.cpython-312.pyc +0 -0
- src/langgraphagenticai/__pycache__/main.cpython-312.pyc +0 -0
- src/langgraphagenticai/graph/__pycache__/graph_builder.cpython-312.pyc +0 -0
- src/langgraphagenticai/graph/graph_builder.py +14 -62
- src/langgraphagenticai/nodes/__pycache__/basic_chatbot_node.cpython-312.pyc +0 -0
- src/langgraphagenticai/nodes/__pycache__/blog_generation_node.cpython-312.pyc +0 -0
- src/langgraphagenticai/nodes/__pycache__/chatbot_with_Tool_node.cpython-312.pyc +0 -0
- src/langgraphagenticai/nodes/blog_generation_node.py +106 -213
- src/langgraphagenticai/state/__pycache__/state.cpython-312.pyc +0 -0
- src/langgraphagenticai/state/state.py +29 -24
- src/langgraphagenticai/ui/streamlitui/__pycache__/loadui.cpython-312.pyc +0 -0
- src/langgraphagenticai/ui/streamlitui/display_result.py +71 -71
src/langgraphagenticai/LLMS/__pycache__/chatgptllm.cpython-312.pyc
CHANGED
|
Binary files a/src/langgraphagenticai/LLMS/__pycache__/chatgptllm.cpython-312.pyc and b/src/langgraphagenticai/LLMS/__pycache__/chatgptllm.cpython-312.pyc differ
|
|
|
src/langgraphagenticai/LLMS/__pycache__/geminillm.cpython-312.pyc
CHANGED
|
Binary files a/src/langgraphagenticai/LLMS/__pycache__/geminillm.cpython-312.pyc and b/src/langgraphagenticai/LLMS/__pycache__/geminillm.cpython-312.pyc differ
|
|
|
src/langgraphagenticai/LLMS/__pycache__/groqllm.cpython-312.pyc
CHANGED
|
Binary files a/src/langgraphagenticai/LLMS/__pycache__/groqllm.cpython-312.pyc and b/src/langgraphagenticai/LLMS/__pycache__/groqllm.cpython-312.pyc differ
|
|
|
src/langgraphagenticai/__pycache__/main.cpython-312.pyc
CHANGED
|
Binary files a/src/langgraphagenticai/__pycache__/main.cpython-312.pyc and b/src/langgraphagenticai/__pycache__/main.cpython-312.pyc differ
|
|
|
src/langgraphagenticai/graph/__pycache__/graph_builder.cpython-312.pyc
CHANGED
|
Binary files a/src/langgraphagenticai/graph/__pycache__/graph_builder.cpython-312.pyc and b/src/langgraphagenticai/graph/__pycache__/graph_builder.cpython-312.pyc differ
|
|
|
src/langgraphagenticai/graph/graph_builder.py
CHANGED
|
@@ -145,73 +145,25 @@ class GraphBuilder:
|
|
| 145 |
graph_builder = StateGraph(state_schema=State)
|
| 146 |
blog_node = BlogGenerationNode(self.llm)
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
result["structure"] = ", ".join(self.validate_and_standardize_structure(full_input))
|
| 155 |
-
logger.info(f"Validated structure: {result['structure']}")
|
| 156 |
-
return result
|
| 157 |
-
|
| 158 |
-
# Add nodes with validated user input
|
| 159 |
-
graph_builder.add_node("user_input", user_input_with_validation)
|
| 160 |
-
graph_builder.add_node("outline_generator", blog_node.outline_generator)
|
| 161 |
-
graph_builder.add_node("outline_review", blog_node.outline_review)
|
| 162 |
-
graph_builder.add_node("draft_generator", blog_node.draft_generator)
|
| 163 |
-
graph_builder.add_node("draft_review", blog_node.draft_review)
|
| 164 |
-
graph_builder.add_node("web_search", blog_node.web_search)
|
| 165 |
graph_builder.add_node("revision_generator", blog_node.revision_generator)
|
| 166 |
|
| 167 |
-
#
|
| 168 |
graph_builder.add_edge(START, "user_input")
|
| 169 |
-
graph_builder.add_edge("user_input", "
|
| 170 |
-
graph_builder.
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
logger.error("Invalid state provided")
|
| 176 |
-
return "outline_review"
|
| 177 |
-
feedback = state.get("outline_feedback", "")
|
| 178 |
-
if not feedback: # No feedback yet, stay at review
|
| 179 |
-
return "outline_review"
|
| 180 |
-
return "draft_generator" if feedback == "approved" else "outline_generator"
|
| 181 |
-
|
| 182 |
-
graph_builder.add_conditional_edges(
|
| 183 |
-
"outline_review",
|
| 184 |
-
determine_outline_review,
|
| 185 |
-
{"outline_review": "outline_review", "outline_generator": "outline_generator", "draft_generator": "draft_generator"}
|
| 186 |
-
)
|
| 187 |
-
|
| 188 |
-
# Draft generator conditional edge for web search
|
| 189 |
-
graph_builder.add_conditional_edges(
|
| 190 |
-
"draft_generator",
|
| 191 |
-
lambda state: "web_search" if state.get("needs_facts", False) else "draft_review",
|
| 192 |
-
{"web_search": "web_search", "draft_review": "draft_review"}
|
| 193 |
-
)
|
| 194 |
-
graph_builder.add_edge("web_search", "draft_generator")
|
| 195 |
-
|
| 196 |
-
# Conditional edge for draft review based on button feedback
|
| 197 |
-
def determine_draft_review(state):
|
| 198 |
-
if not state:
|
| 199 |
-
logger.error("Invalid state provided")
|
| 200 |
-
return "draft_review"
|
| 201 |
-
feedback = state.get("draft_feedback", "")
|
| 202 |
-
if not feedback: # No feedback yet, stay at review
|
| 203 |
-
return "draft_review"
|
| 204 |
-
return END if feedback == "approved" else "revision_generator"
|
| 205 |
-
|
| 206 |
-
graph_builder.add_conditional_edges(
|
| 207 |
-
"draft_review",
|
| 208 |
-
determine_draft_review,
|
| 209 |
-
{"draft_review": "draft_review", "revision_generator": "revision_generator", END: END}
|
| 210 |
-
)
|
| 211 |
-
graph_builder.add_edge("revision_generator", "draft_review")
|
| 212 |
|
| 213 |
# Compile with interrupts at review nodes
|
| 214 |
-
return graph_builder.compile(interrupt_before=["
|
| 215 |
except Exception as e:
|
| 216 |
logger.error(f"Error building blog generation graph: {e}")
|
| 217 |
return None
|
|
|
|
| 145 |
graph_builder = StateGraph(state_schema=State)
|
| 146 |
blog_node = BlogGenerationNode(self.llm)
|
| 147 |
|
| 148 |
+
# # Add nodes
|
| 149 |
+
graph_builder.add_node("user_input", blog_node.user_input)
|
| 150 |
+
graph_builder.add_node("orchestrator", blog_node.orchestrator)
|
| 151 |
+
graph_builder.add_node("llm_call", blog_node.llm_call)
|
| 152 |
+
graph_builder.add_node("synthesizer", blog_node.synthesizer)
|
| 153 |
+
graph_builder.add_node("feedback_collector", blog_node.feedback_collector)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
graph_builder.add_node("revision_generator", blog_node.revision_generator)
|
| 155 |
|
| 156 |
+
# Add edges
|
| 157 |
graph_builder.add_edge(START, "user_input")
|
| 158 |
+
graph_builder.add_edge("user_input", "orchestrator")
|
| 159 |
+
graph_builder.add_conditional_edges("orchestrator", lambda state: blog_node.assign_workers(state), ["llm_call"])
|
| 160 |
+
graph_builder.add_edge("llm_call", "synthesizer")
|
| 161 |
+
graph_builder.add_edge("synthesizer", "feedback_collector")
|
| 162 |
+
graph_builder.add_conditional_edges("feedback_collector", blog_node.route_feedback, {"revision_generator": "revision_generator", END: END})
|
| 163 |
+
graph_builder.add_edge("revision_generator", "synthesizer") # Loop back to synthesize revised sections
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
# Compile with interrupts at review nodes
|
| 166 |
+
return graph_builder.compile(interrupt_before=["feedback_collector"], checkpointer=self.memory)
|
| 167 |
except Exception as e:
|
| 168 |
logger.error(f"Error building blog generation graph: {e}")
|
| 169 |
return None
|
src/langgraphagenticai/nodes/__pycache__/basic_chatbot_node.cpython-312.pyc
CHANGED
|
Binary files a/src/langgraphagenticai/nodes/__pycache__/basic_chatbot_node.cpython-312.pyc and b/src/langgraphagenticai/nodes/__pycache__/basic_chatbot_node.cpython-312.pyc differ
|
|
|
src/langgraphagenticai/nodes/__pycache__/blog_generation_node.cpython-312.pyc
CHANGED
|
Binary files a/src/langgraphagenticai/nodes/__pycache__/blog_generation_node.cpython-312.pyc and b/src/langgraphagenticai/nodes/__pycache__/blog_generation_node.cpython-312.pyc differ
|
|
|
src/langgraphagenticai/nodes/__pycache__/chatbot_with_Tool_node.cpython-312.pyc
CHANGED
|
Binary files a/src/langgraphagenticai/nodes/__pycache__/chatbot_with_Tool_node.cpython-312.pyc and b/src/langgraphagenticai/nodes/__pycache__/chatbot_with_Tool_node.cpython-312.pyc differ
|
|
|
src/langgraphagenticai/nodes/blog_generation_node.py
CHANGED
|
@@ -1,41 +1,41 @@
|
|
| 1 |
-
# src/langgraphagenticai/nodes/blog_generation_node.py
|
| 2 |
-
from src.langgraphagenticai.state.state import State, Section
|
| 3 |
-
from pydantic import BaseModel, Field
|
| 4 |
-
from typing import List, Dict, Optional
|
| 5 |
-
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
| 6 |
-
from langchain_core.prompts import ChatPromptTemplate
|
| 7 |
-
from src.langgraphagenticai.tools.search_tool import get_tools
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
import streamlit as st
|
| 10 |
import json
|
| 11 |
|
|
|
|
| 12 |
# Configure logging
|
| 13 |
logging.basicConfig(level=logging.INFO)
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
-
class Sections(BaseModel):
|
| 17 |
-
sections: List[Section] = Field(description="List of sections for the blog report.")
|
| 18 |
-
|
| 19 |
class BlogGenerationNode:
|
| 20 |
def __init__(self, model):
|
| 21 |
"""Initialize the BlogGenerationNode with an LLM."""
|
| 22 |
-
self.planner = model.with_structured_output(Sections)
|
| 23 |
self.llm = model
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
if not self.search_tool:
|
| 27 |
-
logger.warning("No search tool available; web search will be skipped.")
|
| 28 |
-
self.draft_prompt = ChatPromptTemplate.from_messages([
|
| 29 |
-
("system", "You are a blog writer. Generate a section starting with a markdown heading (##) matching the section name. "
|
| 30 |
-
"Use the provided description and web search results to guide the content. Use markdown formatting (e.g., paragraphs, lists). "
|
| 31 |
-
"Focus only on the current section and avoid repeating prior content unless directly relevant. "
|
| 32 |
-
"Web search results: {search_results}"),
|
| 33 |
-
("human", "Section name: {name}\nDescription: {description}")
|
| 34 |
-
])
|
| 35 |
-
self.feedback_prompt = ChatPromptTemplate.from_messages([
|
| 36 |
-
("system", "Refine the blog sections based on feedback: {feedback}. Keep prior content and adjust as requested."),
|
| 37 |
-
("placeholder", "{messages}")
|
| 38 |
-
])
|
| 39 |
|
| 40 |
def user_input(self, state: State) -> dict:
|
| 41 |
"""Handle user input (Node A)."""
|
|
@@ -57,202 +57,95 @@ class BlogGenerationNode:
|
|
| 57 |
logger.info(f"Parsed requirements: {result}")
|
| 58 |
return result
|
| 59 |
|
| 60 |
-
def outline_generator(self, state: State) -> dict:
|
| 61 |
-
"""Generate the blog outline based on user requirements (Node B)."""
|
| 62 |
-
logger.info(f"Executing outline_generator with state: {state}")
|
| 63 |
-
topic = state.get("topic", "No topic provided")
|
| 64 |
-
objective = state.get("objective", "Informative")
|
| 65 |
-
target_audience = state.get("target_audience", "General Audience")
|
| 66 |
-
tone_style = state.get("tone_style", "Casual")
|
| 67 |
-
word_count = state.get("word_count", 1000)
|
| 68 |
-
structure = state.get("structure", "Introduction, Main Points, Conclusion")
|
| 69 |
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
| 71 |
section_count = len(structure_list)
|
| 72 |
-
|
| 73 |
prompt = (
|
| 74 |
-
f"Generate a structured plan for a blog report with exactly {section_count} sections
|
| 75 |
-
f"Ensure the content is relevant to the topic: {topic}. "
|
| 76 |
-
f"The blog's objective is {objective}, aimed at {target_audience},
|
| 77 |
-
f"Target a word count of {word_count} words. "
|
| 78 |
-
f"Use this exact structure and section names: {', '.join(structure_list)}.
|
| 79 |
-
f"
|
| 80 |
)
|
| 81 |
-
|
| 82 |
SystemMessage(content=prompt),
|
| 83 |
-
HumanMessage(content=f"Topic: {topic}")
|
| 84 |
-
]
|
| 85 |
-
logger.info(f"
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
state["messages"].append(AIMessage(content=outline_message))
|
| 107 |
-
state["sections"] = report_sections.sections
|
| 108 |
-
|
| 109 |
-
# Return both sections and messages in the result
|
| 110 |
-
return {
|
| 111 |
-
"sections": report_sections.sections,
|
| 112 |
-
"messages": state["messages"]
|
| 113 |
-
}
|
| 114 |
-
|
| 115 |
-
except Exception as e:
|
| 116 |
-
logger.error(f"LLM invocation failed: {e}")
|
| 117 |
-
state["messages"].append(AIMessage(content=f"Error generating outline: {e}"))
|
| 118 |
-
return {"sections": [], "messages": state["messages"]}
|
| 119 |
-
|
| 120 |
-
def outline_review(self, state: State) -> dict:
|
| 121 |
-
"""Handle human review of the outline (Outline Review-Human, Node C)."""
|
| 122 |
-
logger.info(f"Executing outline_review with state: {state}")
|
| 123 |
-
feedback = st.session_state.get("outline_feedback", "")
|
| 124 |
-
state["outline_feedback"] = feedback # Sync feedback from Streamlit state
|
| 125 |
-
logger.info(f"Outline feedback set: {feedback}")
|
| 126 |
-
return {"outline_approved": feedback == "approved", "outline_feedback": feedback}
|
| 127 |
-
|
| 128 |
-
def web_search(self, state: State) -> dict:
|
| 129 |
-
"""Fetch web search results for each section (Web Search/Scraping, Node I)."""
|
| 130 |
-
logger.info(f"Executing web_search with state: {state}")
|
| 131 |
-
search_results = {}
|
| 132 |
-
if not self.search_tool:
|
| 133 |
-
logger.info("No search tool available; skipping web search.")
|
| 134 |
-
return {"search_results": {s['name']: "No search tool configured." for s in state["sections"]}}
|
| 135 |
-
for section in state["sections"]:
|
| 136 |
-
query = f"{state['topic']} {section['name']}"
|
| 137 |
-
logger.info(f"Searching for: {query}")
|
| 138 |
-
try:
|
| 139 |
-
results = self.search_tool.invoke({"query": query})
|
| 140 |
-
search_results[section['name']] = "\n".join([r.get("content", "") for r in results])
|
| 141 |
-
except Exception as e:
|
| 142 |
-
logger.error(f"Web search failed for {query}: {e}")
|
| 143 |
-
search_results[section['name']] = "No data available due to search error."
|
| 144 |
-
logger.info(f"Search results: {search_results}")
|
| 145 |
-
return {"search_results": search_results}
|
| 146 |
-
|
| 147 |
-
def draft_generator(self, state: State) -> dict:
|
| 148 |
-
"""Generate the initial draft using search results (Draft Generator-LLM, Node D)."""
|
| 149 |
-
logger.info(f"Executing draft_generator with state: {state}")
|
| 150 |
-
if not state.get("search_results"):
|
| 151 |
-
logger.info("No search results found, triggering web search.")
|
| 152 |
-
return {"needs_facts": True}
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
|
| 157 |
-
#
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
}
|
| 164 |
-
messages = self.draft_prompt.format_messages(**prompt_inputs)
|
| 165 |
-
logger.info(f"Draft prompt messages: {messages}")
|
| 166 |
-
try:
|
| 167 |
-
content = self.llm.invoke(messages).content
|
| 168 |
-
# Ensure consistent section formatting
|
| 169 |
-
section_content = self._format_section_content(section['name'], content)
|
| 170 |
-
completed_sections.append(section_content)
|
| 171 |
-
final_draft.append(section_content)
|
| 172 |
-
except Exception as e:
|
| 173 |
-
logger.error(f"Failed to generate section {section['name']}: {e}")
|
| 174 |
-
error_content = f"## {section['name']}\nError: {e}"
|
| 175 |
-
completed_sections.append(error_content)
|
| 176 |
-
final_draft.append(error_content)
|
| 177 |
-
|
| 178 |
-
# Join sections with clear separation and formatting
|
| 179 |
-
draft = "\n\n".join(final_draft)
|
| 180 |
-
|
| 181 |
-
# Create a formatted display version with consistent styling
|
| 182 |
-
display_content = self._format_display_content(draft)
|
| 183 |
-
|
| 184 |
-
# Update state
|
| 185 |
-
state["completed_sections"] = completed_sections
|
| 186 |
-
state["messages"].append(AIMessage(content=display_content))
|
| 187 |
-
|
| 188 |
-
logger.info(f"Generated draft with {len(completed_sections)} sections")
|
| 189 |
-
return {
|
| 190 |
-
"completed_sections": completed_sections,
|
| 191 |
-
"draft_content": draft,
|
| 192 |
-
"needs_facts": False
|
| 193 |
-
}
|
| 194 |
|
| 195 |
-
def
|
| 196 |
-
"""
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
# Ensure section starts with proper heading
|
| 201 |
-
if not content.startswith(f"## {section_name}"):
|
| 202 |
-
content = f"## {section_name}\n\n{content}"
|
| 203 |
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
|
| 208 |
-
for
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
formatted_paragraphs.append(para)
|
| 215 |
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
feedback = st.session_state.get("draft_feedback", "")
|
| 231 |
-
state["draft_feedback"] = feedback # Sync feedback from Streamlit state
|
| 232 |
-
logger.info(f"Draft feedback set: {feedback}")
|
| 233 |
-
if feedback == "approved":
|
| 234 |
-
final_draft = "\n\n---\n\n".join(state["completed_sections"])
|
| 235 |
-
state["messages"].append(AIMessage(content=f"Final approved draft:\n{final_draft}"))
|
| 236 |
-
logger.info(f"Updated state with final draft: {state}")
|
| 237 |
-
return {"draft_approved": feedback == "approved", "feedback": feedback}
|
| 238 |
-
|
| 239 |
-
def revision_generator(self, state: State) -> dict:
|
| 240 |
-
"""Refine the draft based on human feedback (Revision Generator-LLM, Node F)."""
|
| 241 |
-
logger.info(f"Executing revision_generator with state: {state}")
|
| 242 |
-
feedback = state.get("feedback", "")
|
| 243 |
-
prompt_inputs = {
|
| 244 |
-
"feedback": feedback,
|
| 245 |
-
"messages": state["messages"]
|
| 246 |
-
}
|
| 247 |
-
messages = self.feedback_prompt.format_messages(**prompt_inputs)
|
| 248 |
-
logger.info(f"Feedback prompt messages: {messages}")
|
| 249 |
-
try:
|
| 250 |
-
refined = self.llm.invoke(messages).content.split("\n\n---\n\n")
|
| 251 |
-
draft = "\n\n---\n\n".join(refined)
|
| 252 |
-
state["messages"].append(AIMessage(content=f"Refined draft:\n{draft}\nPlease review again with the buttons below."))
|
| 253 |
-
logger.info(f"Updated state with refined draft: {state}")
|
| 254 |
-
return {"completed_sections": refined, "feedback": feedback}
|
| 255 |
-
except Exception as e:
|
| 256 |
-
logger.error(f"Failed to refine draft: {e}")
|
| 257 |
-
state["messages"].append(AIMessage(content=f"Error refining draft: {e}"))
|
| 258 |
-
return {"completed_sections": state["completed_sections"], "feedback": feedback}
|
|
|
|
| 1 |
+
# # src/langgraphagenticai/nodes/blog_generation_node.py
|
| 2 |
+
# from src.langgraphagenticai.state.state import State, Section
|
| 3 |
+
# from pydantic import BaseModel, Field
|
| 4 |
+
# from typing import List, Dict, Optional
|
| 5 |
+
# from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
|
| 6 |
+
# from langchain_core.prompts import ChatPromptTemplate
|
| 7 |
+
# from src.langgraphagenticai.tools.search_tool import get_tools
|
| 8 |
+
# import logging
|
| 9 |
+
# import streamlit as st
|
| 10 |
+
# import json
|
| 11 |
+
|
| 12 |
+
# # Configure logging
|
| 13 |
+
# logging.basicConfig(level=logging.INFO)
|
| 14 |
+
# logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
import logging
|
| 19 |
+
from langgraph.graph import StateGraph, START, END
|
| 20 |
+
from langgraph.constants import Send
|
| 21 |
+
from src.langgraphagenticai.state.state import State, WorkerState, Sections, Section # Import from state.py
|
| 22 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
| 23 |
+
from IPython.display import Markdown, display
|
| 24 |
+
from PIL import Image
|
| 25 |
import streamlit as st
|
| 26 |
import json
|
| 27 |
|
| 28 |
+
|
| 29 |
# Configure logging
|
| 30 |
logging.basicConfig(level=logging.INFO)
|
| 31 |
logger = logging.getLogger(__name__)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
| 33 |
class BlogGenerationNode:
|
| 34 |
def __init__(self, model):
|
| 35 |
"""Initialize the BlogGenerationNode with an LLM."""
|
|
|
|
| 36 |
self.llm = model
|
| 37 |
+
self.planner = model.with_structured_output(Sections)
|
| 38 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
def user_input(self, state: State) -> dict:
|
| 41 |
"""Handle user input (Node A)."""
|
|
|
|
| 57 |
logger.info(f"Parsed requirements: {result}")
|
| 58 |
return result
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
+
def orchestrator(self,state: State) -> dict:
|
| 62 |
+
"""Orchestrator that generates a plan for the report."""
|
| 63 |
+
logger.info(f"Executing orchestrator with state: {state}")
|
| 64 |
+
structure_list = [s.strip() for s in state["structure"].split(",")]
|
| 65 |
section_count = len(structure_list)
|
| 66 |
+
|
| 67 |
prompt = (
|
| 68 |
+
f"Generate a structured plan for a blog report with exactly {section_count} sections. "
|
| 69 |
+
f"Ensure the content is relevant to the topic: {state['topic']}. "
|
| 70 |
+
f"The blog's objective is {state['objective']}, aimed at {state['target_audience']}, "
|
| 71 |
+
f"with a {state['tone_style']} tone. Target a word count of {state['word_count']} words. "
|
| 72 |
+
f"Use this exact structure and section names: {', '.join(structure_list)}. "
|
| 73 |
+
f"Do not add extra sections or change the names."
|
| 74 |
)
|
| 75 |
+
report_sections = self.planner.invoke([
|
| 76 |
SystemMessage(content=prompt),
|
| 77 |
+
HumanMessage(content=f"Topic: {state['topic']}")
|
| 78 |
+
])
|
| 79 |
+
logger.info(f"Report Sections: {report_sections}")
|
| 80 |
+
return {"sections": report_sections.sections}
|
| 81 |
+
|
| 82 |
+
def llm_call(self,state: WorkerState) -> dict:
|
| 83 |
+
"""Worker writes a section of the report."""
|
| 84 |
+
section = self.llm.invoke([
|
| 85 |
+
SystemMessage(content="Write a report section following the provided name and description. Include no preamble for each section. Use markdown formatting."),
|
| 86 |
+
HumanMessage(content=f"Here is the section name: {state['section'].name} and description: {state['section'].description}")
|
| 87 |
+
])
|
| 88 |
+
return {"completed_sections": [section.content]}
|
| 89 |
+
|
| 90 |
+
def synthesizer(self,state: State) -> dict:
|
| 91 |
+
"""Synthesize full report from sections."""
|
| 92 |
+
completed_sections = state["completed_sections"]
|
| 93 |
+
final_report = "\n\n---\n\n".join(completed_sections)
|
| 94 |
+
logger.info(f"Synthesized report: {final_report}")
|
| 95 |
+
return {"final_report": final_report}
|
| 96 |
+
|
| 97 |
+
def feedback_collector(self, state: State) -> dict:
|
| 98 |
+
"""Collect human feedback on the draft using Streamlit interface."""
|
| 99 |
+
logger.info("Executing feedback_collector")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
+
# Get feedback from Streamlit session state
|
| 102 |
+
feedback = st.session_state.get("draft_feedback", "")
|
| 103 |
|
| 104 |
+
# Check if feedback exists
|
| 105 |
+
if feedback:
|
| 106 |
+
logger.info(f"Collected feedback: {feedback}")
|
| 107 |
+
# Set draft_approved in state
|
| 108 |
+
state["draft_approved"] = feedback == "approved"
|
| 109 |
+
return {
|
| 110 |
+
"feedback": feedback,
|
| 111 |
+
"draft_approved": state["draft_approved"]
|
| 112 |
+
}
|
| 113 |
+
else:
|
| 114 |
+
logger.info("No feedback provided yet")
|
| 115 |
+
# Set draft_approved in state
|
| 116 |
+
state["draft_approved"] = False
|
| 117 |
+
return {
|
| 118 |
+
"feedback": "",
|
| 119 |
+
"draft_approved": state["draft_approved"]
|
| 120 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
+
def revision_generator(self,state: State) -> dict:
|
| 123 |
+
"""Revise the report based on human feedback."""
|
| 124 |
+
if state["draft_approved"]:
|
| 125 |
+
return {"final_report": state["final_report"]} # No revision needed
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
+
feedback = state["feedback"]
|
| 128 |
+
completed_sections = state["completed_sections"]
|
| 129 |
+
revised_sections = []
|
| 130 |
|
| 131 |
+
for section in completed_sections:
|
| 132 |
+
revised_content = self.llm.invoke([
|
| 133 |
+
SystemMessage(content=f"Refine this section based on feedback: {feedback}. Keep prior content and adjust as requested. Use markdown formatting."),
|
| 134 |
+
HumanMessage(content=section)
|
| 135 |
+
])
|
| 136 |
+
revised_sections.append(revised_content.content)
|
|
|
|
| 137 |
|
| 138 |
+
logger.info(f"Revised sections: {revised_sections}")
|
| 139 |
+
return {"completed_sections": revised_sections}
|
| 140 |
+
|
| 141 |
+
# Conditional edge function to create llm_call workers
|
| 142 |
+
def assign_workers(self,state: State):
|
| 143 |
+
"""Assign a worker to each section in the plan."""
|
| 144 |
+
return [Send("llm_call", {"section": s}) for s in state["sections"]]
|
| 145 |
+
|
| 146 |
+
# Conditional edge for feedback loop
|
| 147 |
+
def route_feedback(self,state: State):
|
| 148 |
+
"""Route based on whether draft is approved."""
|
| 149 |
+
if state["draft_approved"]:
|
| 150 |
+
return END
|
| 151 |
+
return "revision_generator"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/langgraphagenticai/state/__pycache__/state.cpython-312.pyc
CHANGED
|
Binary files a/src/langgraphagenticai/state/__pycache__/state.cpython-312.pyc and b/src/langgraphagenticai/state/__pycache__/state.cpython-312.pyc differ
|
|
|
src/langgraphagenticai/state/state.py
CHANGED
|
@@ -1,29 +1,34 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
from
|
| 4 |
from langgraph.graph.message import add_messages
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
name: str
|
| 9 |
-
description: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
class State(TypedDict, total=False):
|
| 12 |
-
"""State schema for the LangGraph workflow, with all fields optional."""
|
| 13 |
messages: Annotated[list, add_messages] # Chat history including user inputs and AI responses
|
| 14 |
|
| 15 |
-
topic: str #
|
| 16 |
-
objective: str #
|
| 17 |
-
target_audience: str #
|
| 18 |
-
tone_style: str # Tone
|
| 19 |
-
word_count: int #
|
| 20 |
-
structure: str #
|
| 21 |
-
sections: List[Section] # List of
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
| 1 |
+
from typing import Annotated, List, TypedDict
|
| 2 |
+
import operator
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
from langgraph.graph.message import add_messages
|
| 5 |
|
| 6 |
+
# Schema for structured output to use in planning
|
| 7 |
+
class Section(BaseModel):
|
| 8 |
+
name: str = Field(description="Name for this section of the report.")
|
| 9 |
+
description: str = Field(description="Brief overview of the main topics and concepts to be covered in this section.")
|
| 10 |
+
|
| 11 |
+
class Sections(BaseModel):
|
| 12 |
+
sections: List[Section] = Field(description="Sections of the report.")
|
| 13 |
+
|
| 14 |
+
# Graph state
|
| 15 |
+
class State(TypedDict):
|
| 16 |
|
|
|
|
|
|
|
| 17 |
messages: Annotated[list, add_messages] # Chat history including user inputs and AI responses
|
| 18 |
|
| 19 |
+
topic: str # Report topic from user input
|
| 20 |
+
objective: str # Objective from user input
|
| 21 |
+
target_audience: str # Target audience from user input
|
| 22 |
+
tone_style: str # Tone/style from user input
|
| 23 |
+
word_count: int # Word count from user input
|
| 24 |
+
structure: str # Structure from user input
|
| 25 |
+
sections: List[Section] # List of report sections
|
| 26 |
+
completed_sections: Annotated[List[str], operator.add] # All workers write to this key in parallel
|
| 27 |
+
final_report: str # Final report
|
| 28 |
+
feedback: str # Human feedback
|
| 29 |
+
draft_approved: bool # Whether the draft is approved
|
| 30 |
+
|
| 31 |
+
# Worker state
|
| 32 |
+
class WorkerState(TypedDict):
|
| 33 |
+
section: Section
|
| 34 |
+
completed_sections: Annotated[List[str], operator.add]
|
src/langgraphagenticai/ui/streamlitui/__pycache__/loadui.cpython-312.pyc
CHANGED
|
Binary files a/src/langgraphagenticai/ui/streamlitui/__pycache__/loadui.cpython-312.pyc and b/src/langgraphagenticai/ui/streamlitui/__pycache__/loadui.cpython-312.pyc differ
|
|
|
src/langgraphagenticai/ui/streamlitui/display_result.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
from langchain_core.messages import HumanMessage, AIMessage
|
| 3 |
import logging
|
|
|
|
| 4 |
|
| 5 |
logger = logging.getLogger(__name__)
|
| 6 |
|
|
@@ -19,8 +20,7 @@ class DisplayResultStreamlit:
|
|
| 19 |
defaults = {
|
| 20 |
"waiting_for_feedback": False,
|
| 21 |
"blog_requirements_collected": False,
|
| 22 |
-
"
|
| 23 |
-
"draft_displayed": False,
|
| 24 |
"graph_state": None,
|
| 25 |
"current_session_id": None
|
| 26 |
}
|
|
@@ -36,8 +36,7 @@ class DisplayResultStreamlit:
|
|
| 36 |
store[session_id] = ChatMessageHistory()
|
| 37 |
# Reset display flags if session ID changes
|
| 38 |
if st.session_state.current_session_id != session_id:
|
| 39 |
-
st.session_state.
|
| 40 |
-
st.session_state.draft_displayed = False
|
| 41 |
st.session_state.current_session_id = session_id
|
| 42 |
return store[session_id]
|
| 43 |
|
|
@@ -56,11 +55,11 @@ class DisplayResultStreamlit:
|
|
| 56 |
self._handle_chatbot_input()
|
| 57 |
|
| 58 |
def _handle_blog_generation(self):
|
| 59 |
-
#
|
| 60 |
if not st.session_state.waiting_for_feedback and st.session_state.graph_state:
|
| 61 |
graph_state = st.session_state.graph_state
|
| 62 |
-
if graph_state.next and graph_state.next[0]
|
| 63 |
-
logger.info("
|
| 64 |
st.session_state.waiting_for_feedback = True
|
| 65 |
|
| 66 |
if st.session_state.waiting_for_feedback:
|
|
@@ -88,69 +87,71 @@ class DisplayResultStreamlit:
|
|
| 88 |
submit_button = st.form_submit_button("Submit Blog Requirements")
|
| 89 |
|
| 90 |
if submit_button:
|
| 91 |
-
if not all([topic, objective, target_audience, tone_style
|
| 92 |
st.error("Please fill in all required fields.")
|
| 93 |
return
|
| 94 |
-
user_message = f"Topic: {topic}
|
| 95 |
self.session_history.add_user_message(user_message)
|
| 96 |
with st.chat_message("user"):
|
| 97 |
st.markdown(user_message)
|
| 98 |
# Reset display flags for new blog generation
|
| 99 |
-
st.session_state.
|
| 100 |
-
st.session_state.draft_displayed = False
|
| 101 |
st.session_state.blog_requirements_collected = True
|
| 102 |
self._process_graph_stream(HumanMessage(content=user_message))
|
| 103 |
|
| 104 |
def _process_feedback(self):
|
| 105 |
latest_state = st.session_state.graph_state.values if st.session_state.graph_state else {}
|
| 106 |
-
|
| 107 |
-
|
|
|
|
| 108 |
with st.chat_message("assistant"):
|
| 109 |
-
|
| 110 |
-
st.markdown(
|
| 111 |
-
st.
|
| 112 |
-
st.session_state.outline_displayed = True
|
| 113 |
-
|
| 114 |
-
if "completed_sections" in latest_state and not st.session_state.get("draft_displayed", False):
|
| 115 |
-
with st.chat_message("assistant"):
|
| 116 |
-
draft_content = "\n\n".join(latest_state["completed_sections"])
|
| 117 |
-
st.markdown("### Generated Draft")
|
| 118 |
-
st.markdown(draft_content)
|
| 119 |
-
st.session_state.draft_displayed = True
|
| 120 |
|
|
|
|
| 121 |
current_node = st.session_state.graph_state.next[0] if st.session_state.graph_state.next else None
|
| 122 |
logger.info(f"Current node in _process_feedback: {current_node}")
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
with col1:
|
| 127 |
-
if st.button("Looks good", key="outline_approve"):
|
| 128 |
-
st.session_state.outline_feedback = "approved"
|
| 129 |
-
st.session_state.waiting_for_feedback = False
|
| 130 |
-
logger.info("Outline approved")
|
| 131 |
-
with col2:
|
| 132 |
-
if st.button("Add more details", key="outline_reject"):
|
| 133 |
-
st.session_state.outline_feedback = "add_more_details"
|
| 134 |
-
st.session_state.waiting_for_feedback = False
|
| 135 |
-
logger.info("Outline regeneration requested")
|
| 136 |
-
elif current_node == "draft_review":
|
| 137 |
-
st.write("Review the draft:")
|
| 138 |
col1, col2 = st.columns(2)
|
|
|
|
| 139 |
with col1:
|
| 140 |
-
if st.button("
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
| 145 |
with col2:
|
| 146 |
-
|
| 147 |
-
st.
|
| 148 |
-
|
| 149 |
-
st.
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
def _handle_chatbot_input(self):
|
| 156 |
user_message = st.chat_input("Enter your message:")
|
|
@@ -171,18 +172,19 @@ class DisplayResultStreamlit:
|
|
| 171 |
with st.chat_message("assistant"):
|
| 172 |
self._display_result(state)
|
| 173 |
self.session_history.add_ai_message(state["messages"][-1].content)
|
|
|
|
| 174 |
graph_state = self.graph.get_state(self.config)
|
| 175 |
logger.info(f"Graph state next: {graph_state.next}")
|
| 176 |
-
|
|
|
|
| 177 |
st.session_state.waiting_for_feedback = True
|
| 178 |
st.session_state.graph_state = graph_state
|
| 179 |
-
logger.info(
|
| 180 |
st.rerun() # Force UI update to ensure feedback buttons appear
|
| 181 |
break
|
| 182 |
elif not graph_state.next and self.usecase == "Blog Generation":
|
| 183 |
st.session_state.blog_requirements_collected = False
|
| 184 |
-
st.session_state.
|
| 185 |
-
st.session_state.draft_displayed = False # Reset for new blog
|
| 186 |
with st.chat_message("assistant"):
|
| 187 |
st.markdown("✅ Blog generation completed!")
|
| 188 |
if st.button("New Blog Generation"):
|
|
@@ -197,23 +199,24 @@ class DisplayResultStreamlit:
|
|
| 197 |
def _display_result(self, response):
|
| 198 |
logger.info(f"Display result response: {response}")
|
| 199 |
if self.usecase == "Blog Generation":
|
| 200 |
-
# Outline display moved to _process_feedback
|
| 201 |
messages = response.get("messages", [])
|
| 202 |
if messages:
|
| 203 |
content = messages[-1].content
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
st.markdown(
|
| 209 |
-
st.
|
| 210 |
-
|
| 211 |
st.markdown(content)
|
| 212 |
|
| 213 |
elif self.usecase == "Basic Chatbot":
|
|
|
|
| 214 |
st.markdown(response.get("messages", [{}])[-1].content)
|
| 215 |
|
| 216 |
elif self.usecase == "Chatbot with Tool":
|
|
|
|
| 217 |
content = response.get("messages", [{}])[-1].content
|
| 218 |
tool_output = response.get("tool_output", "")
|
| 219 |
if tool_output:
|
|
@@ -221,14 +224,11 @@ class DisplayResultStreamlit:
|
|
| 221 |
st.code(tool_output, language="text")
|
| 222 |
st.markdown(content)
|
| 223 |
|
| 224 |
-
elif self.usecase == "Coding Peer Review":
|
| 225 |
-
st.markdown("### Code Review Feedback")
|
| 226 |
-
st.markdown(response.get("review_output", "No review generated."))
|
| 227 |
-
if corrected_code := response.get("corrected_code", ""):
|
| 228 |
-
st.markdown("### Corrected Code")
|
| 229 |
-
st.code(corrected_code, language="python")
|
| 230 |
-
|
| 231 |
def _format_blog_content(self, content):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
sections = content.strip().split("\n\n")
|
| 233 |
formatted = "\n\n".join(
|
| 234 |
f"\n\n{s.strip()}" if s.startswith("#") else s.strip()
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from langchain_core.messages import HumanMessage, AIMessage
|
| 3 |
import logging
|
| 4 |
+
import json
|
| 5 |
|
| 6 |
logger = logging.getLogger(__name__)
|
| 7 |
|
|
|
|
| 20 |
defaults = {
|
| 21 |
"waiting_for_feedback": False,
|
| 22 |
"blog_requirements_collected": False,
|
| 23 |
+
"content_displayed": False,
|
|
|
|
| 24 |
"graph_state": None,
|
| 25 |
"current_session_id": None
|
| 26 |
}
|
|
|
|
| 36 |
store[session_id] = ChatMessageHistory()
|
| 37 |
# Reset display flags if session ID changes
|
| 38 |
if st.session_state.current_session_id != session_id:
|
| 39 |
+
st.session_state.content_displayed = False
|
|
|
|
| 40 |
st.session_state.current_session_id = session_id
|
| 41 |
return store[session_id]
|
| 42 |
|
|
|
|
| 55 |
self._handle_chatbot_input()
|
| 56 |
|
| 57 |
def _handle_blog_generation(self):
|
| 58 |
+
# Check if we're waiting for feedback
|
| 59 |
if not st.session_state.waiting_for_feedback and st.session_state.graph_state:
|
| 60 |
graph_state = st.session_state.graph_state
|
| 61 |
+
if graph_state.next and graph_state.next[0] == "feedback_collector":
|
| 62 |
+
logger.info("Setting waiting_for_feedback based on graph state")
|
| 63 |
st.session_state.waiting_for_feedback = True
|
| 64 |
|
| 65 |
if st.session_state.waiting_for_feedback:
|
|
|
|
| 87 |
submit_button = st.form_submit_button("Submit Blog Requirements")
|
| 88 |
|
| 89 |
if submit_button:
|
| 90 |
+
if not all([topic, objective, target_audience, tone_style]):
|
| 91 |
st.error("Please fill in all required fields.")
|
| 92 |
return
|
| 93 |
+
user_message = f"Topic: {topic}\nObjective: {objective}\nTarget Audience: {target_audience}\nTone & Style: {tone_style}\nWord Count: {word_count}\nStructure: {structure}"
|
| 94 |
self.session_history.add_user_message(user_message)
|
| 95 |
with st.chat_message("user"):
|
| 96 |
st.markdown(user_message)
|
| 97 |
# Reset display flags for new blog generation
|
| 98 |
+
st.session_state.content_displayed = False
|
|
|
|
| 99 |
st.session_state.blog_requirements_collected = True
|
| 100 |
self._process_graph_stream(HumanMessage(content=user_message))
|
| 101 |
|
| 102 |
def _process_feedback(self):
|
| 103 |
latest_state = st.session_state.graph_state.values if st.session_state.graph_state else {}
|
| 104 |
+
|
| 105 |
+
# Display the content if it hasn't been displayed yet
|
| 106 |
+
if "blog_content" in latest_state and not st.session_state.get("content_displayed", False):
|
| 107 |
with st.chat_message("assistant"):
|
| 108 |
+
st.markdown("### Generated Blog Content")
|
| 109 |
+
st.markdown(self._format_blog_content(latest_state["blog_content"]))
|
| 110 |
+
st.session_state.content_displayed = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
# Check if we're at the feedback collection node
|
| 113 |
current_node = st.session_state.graph_state.next[0] if st.session_state.graph_state.next else None
|
| 114 |
logger.info(f"Current node in _process_feedback: {current_node}")
|
| 115 |
+
|
| 116 |
+
if current_node == "feedback_collector":
|
| 117 |
+
st.write("### Review the generated content:")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
col1, col2 = st.columns(2)
|
| 119 |
+
|
| 120 |
with col1:
|
| 121 |
+
if st.button("Approve", key="content_approve"):
|
| 122 |
+
feedback = {
|
| 123 |
+
"approved": True,
|
| 124 |
+
"comments": "Content approved."
|
| 125 |
+
}
|
| 126 |
+
self._submit_feedback(feedback)
|
| 127 |
+
|
| 128 |
with col2:
|
| 129 |
+
with st.expander("Request Revisions"):
|
| 130 |
+
comments = st.text_area("Provide revision comments:",
|
| 131 |
+
placeholder="Please explain what changes you would like to see.")
|
| 132 |
+
if st.button("Submit Revisions"):
|
| 133 |
+
if not comments:
|
| 134 |
+
st.error("Please provide revision comments.")
|
| 135 |
+
else:
|
| 136 |
+
feedback = {
|
| 137 |
+
"approved": False,
|
| 138 |
+
"comments": comments
|
| 139 |
+
}
|
| 140 |
+
self._submit_feedback(feedback)
|
| 141 |
+
|
| 142 |
+
def _submit_feedback(self, feedback):
|
| 143 |
+
"""Submit feedback to the graph and continue processing."""
|
| 144 |
+
try:
|
| 145 |
+
# Convert feedback to the expected ReviewFeedback format
|
| 146 |
+
feedback_json = json.dumps(feedback)
|
| 147 |
+
st.session_state.waiting_for_feedback = False
|
| 148 |
+
st.session_state.content_displayed = False
|
| 149 |
+
# Continue processing with the feedback
|
| 150 |
+
self._process_graph_stream(HumanMessage(content=feedback_json))
|
| 151 |
+
logger.info(f"Feedback submitted: {feedback}")
|
| 152 |
+
except Exception as e:
|
| 153 |
+
logger.error(f"Error submitting feedback: {e}")
|
| 154 |
+
st.error(f"Error submitting feedback: {e}")
|
| 155 |
|
| 156 |
def _handle_chatbot_input(self):
|
| 157 |
user_message = st.chat_input("Enter your message:")
|
|
|
|
| 172 |
with st.chat_message("assistant"):
|
| 173 |
self._display_result(state)
|
| 174 |
self.session_history.add_ai_message(state["messages"][-1].content)
|
| 175 |
+
|
| 176 |
graph_state = self.graph.get_state(self.config)
|
| 177 |
logger.info(f"Graph state next: {graph_state.next}")
|
| 178 |
+
|
| 179 |
+
if graph_state.next and graph_state.next[0] == "feedback_collector":
|
| 180 |
st.session_state.waiting_for_feedback = True
|
| 181 |
st.session_state.graph_state = graph_state
|
| 182 |
+
logger.info("Paused for feedback collection")
|
| 183 |
st.rerun() # Force UI update to ensure feedback buttons appear
|
| 184 |
break
|
| 185 |
elif not graph_state.next and self.usecase == "Blog Generation":
|
| 186 |
st.session_state.blog_requirements_collected = False
|
| 187 |
+
st.session_state.content_displayed = False # Reset for new blog
|
|
|
|
| 188 |
with st.chat_message("assistant"):
|
| 189 |
st.markdown("✅ Blog generation completed!")
|
| 190 |
if st.button("New Blog Generation"):
|
|
|
|
| 199 |
def _display_result(self, response):
|
| 200 |
logger.info(f"Display result response: {response}")
|
| 201 |
if self.usecase == "Blog Generation":
|
|
|
|
| 202 |
messages = response.get("messages", [])
|
| 203 |
if messages:
|
| 204 |
content = messages[-1].content
|
| 205 |
+
blog_content = response.get("blog_content", "")
|
| 206 |
+
|
| 207 |
+
if blog_content and not st.session_state.content_displayed:
|
| 208 |
+
st.markdown("### Generated Blog Content")
|
| 209 |
+
st.markdown(self._format_blog_content(blog_content))
|
| 210 |
+
st.session_state.content_displayed = True
|
| 211 |
+
else:
|
| 212 |
st.markdown(content)
|
| 213 |
|
| 214 |
elif self.usecase == "Basic Chatbot":
|
| 215 |
+
# Kept exactly as in original code
|
| 216 |
st.markdown(response.get("messages", [{}])[-1].content)
|
| 217 |
|
| 218 |
elif self.usecase == "Chatbot with Tool":
|
| 219 |
+
# Kept exactly as in original code
|
| 220 |
content = response.get("messages", [{}])[-1].content
|
| 221 |
tool_output = response.get("tool_output", "")
|
| 222 |
if tool_output:
|
|
|
|
| 224 |
st.code(tool_output, language="text")
|
| 225 |
st.markdown(content)
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
def _format_blog_content(self, content):
|
| 228 |
+
"""Format blog content for better display in Streamlit."""
|
| 229 |
+
if not content:
|
| 230 |
+
return ""
|
| 231 |
+
|
| 232 |
sections = content.strip().split("\n\n")
|
| 233 |
formatted = "\n\n".join(
|
| 234 |
f"\n\n{s.strip()}" if s.startswith("#") else s.strip()
|