DrishtiSharma commited on
Commit
b13c344
·
verified ·
1 Parent(s): d619f18

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +176 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,178 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import os
2
+ import re
3
+ import json
4
  import streamlit as st
5
+ from pathlib import Path
6
+ from typing import List, Annotated, Any
7
+ import operator
8
+ import pandas as pd
9
+ from tqdm import tqdm
10
+ from pydantic import BaseModel
11
+ from langchain.embeddings.cohere import CohereEmbeddings
12
+ from langchain_cohere import ChatCohere
13
+ from langchain.document_loaders import DirectoryLoader, TextLoader
14
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
15
+ from langchain.vectorstores import Chroma
16
+ import cohere
17
+ from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, AIMessage
18
+ from langgraph.graph import StateGraph, START, END, add_messages
19
+ from langgraph.constants import Send
20
+ from langgraph.checkpoint.memory import MemorySaver
21
+
22
+ load_dotenv()
23
+ os.environ["user_agent"] = "langchain-app/1.0"
24
+ COHERE_API_KEY = os.environ["COHERE_API_KEY"]
25
+ co = cohere.Client(COHERE_API_KEY)
26
+ persist_dir = "./chroma_store"
27
+
28
+ def prepare_vectorstore():
29
+ loader = DirectoryLoader("./documents", glob="**/*.txt", loader_cls=TextLoader)
30
+ documents = loader.load()
31
+ splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
32
+ docs = splitter.split_documents(documents)
33
+ embedding = CohereEmbeddings(
34
+ model="embed-multilingual-light-v3.0",
35
+ user_agent="langchain-app/1.0",
36
+ cohere_api_key=""
37
+ )
38
+ vectorstore = Chroma.from_documents(
39
+ documents=tqdm(docs, desc="Embedding"),
40
+ embedding=embedding,
41
+ persist_directory=persist_dir
42
+ )
43
+ vectorstore.persist()
44
+ return vectorstore
45
+
46
+ if not os.path.exists(persist_dir):
47
+ prepare_vectorstore()
48
+
49
+ class State(BaseModel):
50
+ state: List[str] = []
51
+ messages: Annotated[list[AnyMessage], add_messages]
52
+ topic: List[str] = []
53
+ context: List[str] = []
54
+ sub_topic_list: List[str] = []
55
+ sub_topics: Annotated[list[AnyMessage], add_messages]
56
+ stories: Annotated[list[AnyMessage], add_messages]
57
+ stories_lst: Annotated[list, operator.add]
58
+
59
+ class StoryState(BaseModel):
60
+ retrieved_docs: List[Any] = []
61
+ stories: Annotated[list[AnyMessage], add_messages]
62
+ reranked_docs: List[str] = []
63
+ story_topic: str = ""
64
+ stories_lst: Annotated[list, operator.add]
65
+
66
+ def extract_topics(messages):
67
+ topics = []
68
+ for message in messages:
69
+ topics.extend(re.findall(r'- \*\*(.*?)\*\*', message.content))
70
+ return topics
71
+
72
+ embedding_llm = CohereEmbeddings(
73
+ model="embed-multilingual-light-v3.0",
74
+ user_agent="langchain-app/1.0",
75
+ cohere_api_key=COHERE_API_KEY
76
+ )
77
+ llm = ChatCohere(
78
+ api_version="2024-02-15-preview",
79
+ temperature=0.7,
80
+ model="command-r-plus-08-2024",
81
+ cohere_api_key=COHERE_API_KEY
82
+ )
83
+
84
+ beginner_topic_sys_msg = SystemMessage(content="Suppose you are a middle grader who wants to learn constantly about new topics to get a good score in exams.")
85
+ middle_topic_sys_msg = SystemMessage(content="Suppose you are a college student who wants to learn constantly about new topics to get a good score in exams.")
86
+ advanced_topic_sys_msg = SystemMessage(content="Suppose you are a teacher who wants to learn constantly about new topics to teach your students.")
87
+
88
+ def retrieve_node(state):
89
+ topic = state.story_topic
90
+ query = f"information about {topic}"
91
+ retriever = Chroma(persist_directory=persist_dir, embedding_function=embedding_llm).as_retriever(search_kwargs={"k": 20})
92
+ docs = retriever.get_relevant_documents(query)
93
+ return {"retrieved_docs": docs, "question": query}
94
+
95
+ def rerank_node(state):
96
+ topic = state.story_topic
97
+ query = f"Rerank documents based on how good they explain the topic {topic}"
98
+ docs = state.retrieved_docs
99
+ texts = [doc.page_content for doc in docs]
100
+ rerank_results = co.rerank(query=query, documents=texts, top_n=5, model="rerank-v3.5")
101
+ top_docs = [texts[result.index] for result in rerank_results.results]
102
+ return {"reranked_docs": top_docs, "question": query}
103
+
104
+ def generate_story_node(state):
105
+ context = "\n\n".join(state.reranked_docs)
106
+ topic = state.story_topic
107
+ system_message = """
108
+ Suppose You're a Amazing story writter and scientific thinker.
109
+ You have written hundreds of story books explaining scientific topic in childlike manner that exven amiddle grader could understand.
110
+ You add a subtle humor to your story to make it more life like.
111
+ """
112
+ prompt = f"""
113
+ Now Use the following context to generate a simple engaging story that explains {topic} in such a way an middle schooler can understand the {topic}.\n
114
+ Context:\n{context}\n\n
115
+ Story:
116
+ """
117
+ response = llm.invoke([SystemMessage(system_message), HumanMessage(prompt)])
118
+ return {"stories": response}
119
+
120
+ def beginner_topic(state: State):
121
+ prompt = f"What are the beginner-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}?"
122
+ sub_topics = [llm.invoke([beginner_topic_sys_msg] + [prompt])]
123
+ return {"message": sub_topics[0], "sub_topics": sub_topics[0]}
124
+
125
+ def middle_topic(state: State):
126
+ prompt = f"What are the middle-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}? Don't include the topics below:\n\n{(state.sub_topics)}"
127
+ sub_topics = [llm.invoke([middle_topic_sys_msg] + [prompt])]
128
+ return {"message": sub_topics, "sub_topics": sub_topics}
129
+
130
+ def advanced_topic(state: State):
131
+ prompt = f"What are the advanced-level topics you can learn about {', '.join(state.topic)} in {', '.join(state.context)}? Don't include the topics below:\n\n{(state.sub_topics)}"
132
+ sub_topics = [llm.invoke([advanced_topic_sys_msg] + [prompt])]
133
+ return {"message": sub_topics, "sub_topics": sub_topics}
134
+
135
+ def topic_extractor(state: State):
136
+ return {"sub_topic_list": extract_topics(state.sub_topics)}
137
+
138
+ def dynamic_topic_edges(state: State):
139
+ return [Send("story_generator", {"story_topic": topic}) for topic in state.sub_topic_list]
140
+
141
+ story_builder = StateGraph(StoryState)
142
+ story_builder.add_node("Retrieve", retrieve_node)
143
+ story_builder.add_node("Rerank", rerank_node)
144
+ story_builder.add_node("Generate", generate_story_node)
145
+ story_builder.set_entry_point("Retrieve")
146
+ story_builder.add_edge("Retrieve", "Rerank")
147
+ story_builder.add_edge("Rerank", "Generate")
148
+ story_builder.set_finish_point("Generate")
149
+
150
+ story_graph = story_builder.compile()
151
+
152
+ main_builder = StateGraph(State)
153
+ main_builder.add_node("beginner_topic", beginner_topic)
154
+ main_builder.add_node("middle_topic", middle_topic)
155
+ main_builder.add_node("advanced_topic", advanced_topic)
156
+ main_builder.add_node("topic_extractor", topic_extractor)
157
+ main_builder.add_node("story_generator", story_graph)
158
+ main_builder.add_edge(START, "beginner_topic")
159
+ main_builder.add_edge("beginner_topic", "middle_topic")
160
+ main_builder.add_edge("middle_topic", "advanced_topic")
161
+ main_builder.add_edge("advanced_topic", "topic_extractor")
162
+ main_builder.add_conditional_edges("topic_extractor", dynamic_topic_edges, ["story_generator"])
163
+ main_builder.add_edge("story_generator", END)
164
+
165
+ memory = MemorySaver()
166
+ react_graph = main_builder.compile(checkpointer=memory, interrupt_after=["topic_extractor"])
167
+
168
+ st.title("LangGraph Topic Story Generator")
169
+ topic = st.text_input("Enter a topic", "Human Evolution")
170
+ context = st.text_input("Enter a context", "Science")
171
 
172
+ if st.button("Generate Stories"):
173
+ thread = {"configurable": {"thread_id": "1"}}
174
+ react_graph.invoke({"topic": [topic], "context": [context]}, thread)
175
+ react_graph.update_state(thread, {"sub_topic_list": ['Early Hominins', 'Fossil Evidence', "Darwin's Theory of Evolution"]})
176
+ result = react_graph.invoke(None, thread, stream_mode="values")
177
+ for story in result["stories"]:
178
+ st.markdown(story.content)