Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -27,7 +27,29 @@ import os
|
|
| 27 |
import tempfile
|
| 28 |
from datetime import datetime
|
| 29 |
import pytz
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
class DocumentRAG:
|
| 33 |
def __init__(self):
|
|
@@ -248,10 +270,82 @@ class DocumentRAG:
|
|
| 248 |
except Exception as e:
|
| 249 |
return history + [("System", f"Error: {str(e)}")]
|
| 250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
# Initialize RAG system in session state
|
| 252 |
if "rag_system" not in st.session_state:
|
| 253 |
st.session_state.rag_system = DocumentRAG()
|
| 254 |
|
|
|
|
|
|
|
|
|
|
| 255 |
# Sidebar
|
| 256 |
with st.sidebar:
|
| 257 |
st.title("About")
|
|
@@ -337,7 +431,31 @@ if st.session_state.rag_system.qa_chain:
|
|
| 337 |
else:
|
| 338 |
st.info("Please process documents first to enable Q&A.")
|
| 339 |
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
st.subheader("Step 4: Generate Podcast")
|
| 342 |
st.write("Select Podcast Language:")
|
| 343 |
podcast_language_options = ["English", "Hindi", "Spanish", "French", "German", "Chinese", "Japanese"]
|
|
@@ -348,6 +466,7 @@ podcast_language = st.radio(
|
|
| 348 |
key="podcast_language"
|
| 349 |
)
|
| 350 |
|
|
|
|
| 351 |
if st.session_state.rag_system.document_summary:
|
| 352 |
if st.button("Generate Podcast"):
|
| 353 |
with st.spinner("Generating podcast, please wait..."):
|
|
|
|
| 27 |
import tempfile
|
| 28 |
from datetime import datetime
|
| 29 |
import pytz
|
| 30 |
+
from langgraph.graph import StateGraph, START, END, Send, add_messages
|
| 31 |
+
from langgraph.checkpoint.memory import MemorySaver
|
| 32 |
+
from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage
|
| 33 |
+
from pydantic import BaseModel
|
| 34 |
+
from typing import List, Annotated, Any
|
| 35 |
+
import re, operator
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class MultiAgentState(BaseModel):
    # Shared state for the multi-agent topic-exploration graph.
    # Fields wrapped in Annotated[..., reducer] are merged by LangGraph with
    # that reducer when multiple nodes write to them: add_messages appends
    # message objects, operator.add concatenates lists.
    state: List[str] = []
    messages: Annotated[list[AnyMessage], add_messages]
    topic: List[str] = []  # main topic(s) supplied by the caller
    context: List[str] = []  # learning context(s), e.g. "Education"
    sub_topic_list: List[str] = []  # plain-string subtopics parsed out of sub_topics
    sub_topics: Annotated[list[AnyMessage], add_messages]  # raw LLM replies listing subtopics
    stories: Annotated[list[AnyMessage], add_messages]  # story messages merged up from the subgraph
    stories_lst: Annotated[list, operator.add]  # NOTE(review): never written by visible nodes — confirm it is still needed
|
| 47 |
+
|
| 48 |
+
class StoryState(BaseModel):
    # Per-subtopic state for the story-generation subgraph; one instance is
    # spawned per Send(...) from the topic_extractor fan-out.
    retrieved_docs: List[Any] = []  # documents fetched for story_topic
    stories: Annotated[list[AnyMessage], add_messages]  # generated story; shared channel name lets it merge into the parent state
    story_topic: str = ""  # the single subtopic this subgraph instance narrates
    stories_lst: Annotated[list, operator.add]  # NOTE(review): apparently unused by visible nodes — confirm
|
| 53 |
|
| 54 |
class DocumentRAG:
|
| 55 |
def __init__(self):
|
|
|
|
| 270 |
except Exception as e:
|
| 271 |
return history + [("System", f"Error: {str(e)}")]
|
| 272 |
|
| 273 |
+
def extract_subtopics(self, messages):
    """Collect every bold bullet item (``- **Name**``) across *messages*.

    Each message's ``content`` is concatenated and scanned; matches are
    returned in document order as plain strings.
    """
    combined = "\n".join(msg.content for msg in messages)
    return re.findall(r'- \*\*(.*?)\*\*', combined)
|
| 276 |
+
|
| 277 |
+
def beginner_topic(self, state: MultiAgentState):
    """Ask the LLM for beginner-level subtopics of the requested topic.

    Returns a partial state update for LangGraph; keys must name declared
    state channels.
    """
    prompt = f"What are the beginner-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}?"
    msg = self.llm.invoke([SystemMessage("Suppose you're a middle grader..."), HumanMessage(prompt)])
    # Bug fix: the original returned {"message": msg}; MultiAgentState has no
    # "message" channel (the field is "messages"), and LangGraph rejects
    # writes to undeclared channels.
    return {"messages": msg, "sub_topics": msg}
|
| 281 |
+
|
| 282 |
+
def middle_topic(self, state: MultiAgentState):
    """Ask the LLM for intermediate-level subtopics, avoiding earlier ones.

    Returns a partial state update for LangGraph; keys must name declared
    state channels.
    """
    prompt = f"What are the middle-level topics for {', '.join(state.topic)} in {', '.join(state.context)}? Avoid previous."
    msg = self.llm.invoke([SystemMessage("Suppose you're a college student..."), HumanMessage(prompt)])
    # Bug fix: "message" is not a declared channel on MultiAgentState; the
    # field is "messages", and LangGraph rejects writes to unknown keys.
    return {"messages": msg, "sub_topics": msg}
|
| 286 |
+
|
| 287 |
+
def advanced_topic(self, state: MultiAgentState):
    """Ask the LLM for advanced-level subtopics, avoiding earlier ones.

    Returns a partial state update for LangGraph; keys must name declared
    state channels.
    """
    prompt = f"What are the advanced-level topics for {', '.join(state.topic)} in {', '.join(state.context)}? Avoid previous."
    msg = self.llm.invoke([SystemMessage("Suppose you're a teacher..."), HumanMessage(prompt)])
    # Bug fix: "message" is not a declared channel on MultiAgentState; the
    # field is "messages", and LangGraph rejects writes to unknown keys.
    return {"messages": msg, "sub_topics": msg}
|
| 291 |
+
|
| 292 |
+
def topic_extractor(self, state: MultiAgentState):
    """Flatten the accumulated subtopic messages into plain subtopic strings."""
    extracted = self.extract_subtopics(state.sub_topics)
    return {"sub_topic_list": extracted}
|
| 294 |
+
|
| 295 |
+
def retrieve_docs(self, state: StoryState):
    """Fetch up to 20 documents relevant to the current story topic."""
    query = f"information about {state.story_topic}"
    retriever = self.document_store.as_retriever(search_kwargs={"k": 20})
    return {"retrieved_docs": retriever.get_relevant_documents(query)}
|
| 299 |
+
|
| 300 |
+
def generate_story(self, state: StoryState):
    """Turn the top retrieved documents into a child-friendly story message."""
    # Only the five highest-ranked documents are used as grounding context.
    top_docs = state.retrieved_docs[:5]
    context = "\n\n".join(doc.page_content for doc in top_docs)
    prompt = f"""You're a witty science storyteller. Create a short, child-friendly story that explains **{state.story_topic}** based on this:\n\n{context}"""
    msg = self.llm.invoke([SystemMessage("Use humor. Be clear."), HumanMessage(prompt)])
    return {"stories": msg}
|
| 305 |
+
|
| 306 |
+
def run_multiagent_storygraph(self, topic: str, context: str):
    """Build and run the multi-agent topic/story graph.

    Pipeline: beginner -> middle -> advanced subtopic brainstorming, then
    subtopic extraction, then a Send() fan-out that runs the story subgraph
    once per extracted subtopic. Returns the final state as a dict.
    """
    self.llm = ChatOpenAI(model_name="gpt-4", temperature=0.7, api_key=self.api_key)

    # Story subgraph: retrieve supporting docs, then write the story.
    story_graph = StateGraph(StoryState)
    story_graph.add_node("Retrieve", self.retrieve_docs)
    story_graph.add_node("Generate", self.generate_story)
    story_graph.set_entry_point("Retrieve")
    story_graph.add_edge("Retrieve", "Generate")
    story_graph.set_finish_point("Generate")
    story_subgraph = story_graph.compile()

    # Main graph: linear brainstorming chain, then fan-out to stories.
    graph = StateGraph(MultiAgentState)
    graph.add_node("beginner_topic", self.beginner_topic)
    graph.add_node("middle_topic", self.middle_topic)
    graph.add_node("advanced_topic", self.advanced_topic)
    graph.add_node("topic_extractor", self.topic_extractor)
    graph.add_node("story_generator", story_subgraph)

    graph.add_edge(START, "beginner_topic")
    graph.add_edge("beginner_topic", "middle_topic")
    graph.add_edge("middle_topic", "advanced_topic")
    graph.add_edge("advanced_topic", "topic_extractor")
    # One story_generator invocation per extracted subtopic.
    graph.add_conditional_edges("topic_extractor",
                                lambda state: [Send("story_generator", {"story_topic": t}) for t in state.sub_topic_list],
                                ["story_generator"])
    graph.add_edge("story_generator", END)

    compiled = graph.compile(checkpointer=MemorySaver())
    # Bug fix: a graph compiled with a checkpointer must be invoked with a
    # thread_id in the config; the original invoke() omitted it, which makes
    # LangGraph raise ("Checkpointer requires ... configurable ... thread_id").
    config = {"configurable": {"thread_id": "storygraph"}}
    result = compiled.invoke({"topic": [topic], "context": [context]}, config)
    return result
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
|
| 342 |
# Initialize RAG system in session state: create it once per browser session
# so Streamlit script reruns reuse the same DocumentRAG instance.
if "rag_system" not in st.session_state:
    st.session_state.rag_system = DocumentRAG()
|
| 345 |
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
|
| 349 |
# Sidebar
|
| 350 |
with st.sidebar:
|
| 351 |
st.title("About")
|
|
|
|
| 431 |
else:
|
| 432 |
st.info("Please process documents first to enable Q&A.")
|
| 433 |
|
| 434 |
+
|
| 435 |
+
# Multi-Agent Story Explorer section.
# NOTE(review): this comment originally said "Step 4" while the visible
# subheader says "Step 5" — step numbering in the UI looks inconsistent;
# confirm the intended order.
st.subheader("Step 5: Explore Subtopics via Multi-Agent Graph")
story_topic = st.text_input("Enter main topic:", value="Machine Learning")
story_context = st.text_input("Enter learning context:", value="Education")

if st.button("Run Story Graph"):
    with st.spinner("Generating subtopics and stories..."):
        # Runs the full LangGraph pipeline; result is the final graph state.
        result = st.session_state.rag_system.run_multiagent_storygraph(topic=story_topic, context=story_context)

        # List the subtopics the extractor pulled out of the LLM replies.
        subtopics = result.get("sub_topic_list", [])
        st.markdown("### ๐ง Extracted Subtopics")
        for sub in subtopics:
            st.markdown(f"- {sub}")

        # Render each generated story message, if any were produced.
        stories = result.get("stories", [])
        if stories:
            st.markdown("### ๐ Generated Stories")
            for i, story in enumerate(stories):
                st.markdown(f"**Story {i+1}:**")
                st.markdown(story.content)
        else:
            st.warning("No stories were generated.")
|
| 457 |
+
|
| 458 |
+
# Generate Podcast section.
# NOTE(review): this comment originally said "Step 5" but the visible label
# says "Step 4", and the story explorer above is labeled "Step 5" — the step
# numbering looks inconsistent; confirm the intended order.
st.subheader("Step 4: Generate Podcast")
st.write("Select Podcast Language:")
# Languages offered for podcast narration.
podcast_language_options = ["English", "Hindi", "Spanish", "French", "German", "Chinese", "Japanese"]
|
|
|
|
| 466 |
key="podcast_language"
|
| 467 |
)
|
| 468 |
|
| 469 |
+
|
| 470 |
if st.session_state.rag_system.document_summary:
|
| 471 |
if st.button("Generate Podcast"):
|
| 472 |
with st.spinner("Generating podcast, please wait..."):
|