Nobody4591 commited on
Commit
1be17a6
·
verified ·
1 Parent(s): 9403bc6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +268 -0
app.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ from llama_index.core import (
4
+ Document,
5
+ SummaryIndex,
6
+ load_index_from_storage,
7
+ # TODO update this in docs
8
+ VectorStoreIndex,
9
+ StorageContext,
10
+ )
11
+ from llama_index.llms.openai import OpenAI
12
+ from llama_index.core import Settings
13
+
14
+ from llama_index.core.prompts.prompts import QuestionAnswerPrompt, RefinePrompt
15
+
16
+
17
+ # Text QA templates
18
+ DEFAULT_TEXT_QA_PROMPT_TMPL = (
19
+ "Context information is below. \n"
20
+ "---------------------\n"
21
+ "{context_str}"
22
+ "\n---------------------\n"
23
+ "Given the context information, directly answer the following question "
24
+ "(if you don't know the answer, use the best of your knowledge): {query_str}\n"
25
+ )
26
+ TEXT_QA_TEMPLATE = QuestionAnswerPrompt(DEFAULT_TEXT_QA_PROMPT_TMPL)
27
+
28
+ # Refine templates
29
+ DEFAULT_REFINE_PROMPT_TMPL = (
30
+ "The original question is as follows: {query_str}\n"
31
+ "We have provided an existing answer: {existing_answer}\n"
32
+ "We have the opportunity to refine the existing answer "
33
+ "(only if needed) with some more context below.\n"
34
+ "------------\n"
35
+ "{context_msg}\n"
36
+ "------------\n"
37
+ "Given the new context and using the best of your knowledge, improve the existing answer. "
38
+ "If you can't improve the existing answer, just repeat it again. "
39
+ "Do not include un-needed or un-helpful information that is shown in the new context. "
40
+ "Do not mention that you've read the above context."
41
+ )
42
+ DEFAULT_REFINE_PROMPT = RefinePrompt(DEFAULT_REFINE_PROMPT_TMPL)
43
+
44
+
45
+ def get_llm(
46
+ llm_name,
47
+ model_temperature,
48
+ api_key,
49
+ max_tokens=256,
50
+ ):
51
+ os.environ["OPENAI_API_KEY"] = api_key
52
+ return OpenAI(
53
+ temperature=model_temperature,
54
+ model=llm_name,
55
+ max_tokens=max_tokens,
56
+ )
57
+
58
+
59
+ def extract_terms(
60
+ documents,
61
+ term_extract_str,
62
+ llm_name,
63
+ model_temperature,
64
+ api_key,
65
+ ):
66
+ llm = get_llm(
67
+ llm_name,
68
+ model_temperature,
69
+ api_key,
70
+ max_tokens=1024,
71
+ )
72
+
73
+ temp_index = SummaryIndex.from_documents(
74
+ documents,
75
+ )
76
+ query_engine = temp_index.as_query_engine(
77
+ response_mode="tree_summarize",
78
+ llm=llm,
79
+ )
80
+ terms_definitions = str(query_engine.query(term_extract_str))
81
+ terms_definitions = [
82
+ x
83
+ for x in terms_definitions.split("\n")
84
+ if x and "Term:" in x and "Definition:" in x
85
+ ]
86
+ # parse the text into a dict
87
+ terms_to_definition = {
88
+ x.split("Definition:")[0]
89
+ .split("Term:")[-1]
90
+ .strip(): x.split("Definition:")[-1]
91
+ .strip()
92
+ for x in terms_definitions
93
+ }
94
+ return terms_to_definition
95
+
96
+
97
+ DEFAULT_TERMS = {
98
+ "New York City": "The most populous city in the United States, located at the southern tip of New York State, and the largest metropolitan area in the U.S. by both population and urban area.",
99
+ "boroughs": "Five administrative divisions of New York City, each coextensive with a respective county of the state of New York: Brooklyn, Queens, Manhattan, The Bronx, and Staten Island.",
100
+ "metropolitan statistical area": "A geographical region with a relatively high population density at its core and close economic ties throughout the area.",
101
+ "combined statistical area": "A combination of adjacent metropolitan and micropolitan statistical areas in the United States and Puerto Rico that can demonstrate economic or social linkage.",
102
+ "megacities": "A city with a population of over 10 million people.",
103
+ "United Nations": "An intergovernmental organization that aims to maintain international peace and security, develop friendly relations among nations, achieve international cooperation, and be a center for harmonizing the actions of nations.",
104
+ "Pulitzer Prizes": "A series of annual awards for achievements in journalism, literature, and musical composition in the United States.",
105
+ "Times Square": "A major commercial and tourist destination in Manhattan, New York City.",
106
+ "New Netherland": "A Dutch colony in North America that existed from 1614 until 1664.",
107
+ "Dutch West India Company": "A Dutch trading company that operated as a monopoly in New Netherland from 1621 until 1639-1640.",
108
+ "patroon system": "A system instituted by the Dutch to attract settlers to New Netherland, whereby wealthy Dutchmen who brought 50 colonists would be awarded land and local political autonomy.",
109
+ "Peter Stuyvesant": "The last Director-General of New Netherland, who served from 1647 until 1664.",
110
+ "Treaty of Breda": "A treaty signed in 1667 between the Dutch and English that resulted in the Dutch keeping Suriname and the English keeping New Amsterdam (which was renamed New York).",
111
+ "African Burying Ground": "A cemetery discovered in Foley Square in the 1990s that included 10,000 to 20,000 graves of colonial-era Africans, some enslaved and some free.",
112
+ "Stamp Act Congress": "A meeting held in New York in 1765 in response to the Stamp Act, which imposed taxes on printed materials in the American colonies.",
113
+ "Battle of Long Island": "The largest battle of the American Revolutionary War, fought on August 27, 1776, in Brooklyn, New York City.",
114
+ "New York Police Department": "The police force of New York City.",
115
+ "Irish immigrants": "People who immigrated to the United States from Ireland.",
116
+ "lynched": "To kill someone, especially by hanging, without a legal trial.",
117
+ "civil unrest": "A situation in which people in a country are angry and likely to protest or fight.",
118
+ "megacity": "A very large city, typically one with a population of over ten million people.",
119
+ "World Trade Center": "A complex of buildings in Lower Manhattan, New York City, that were destroyed in the September 11 attacks.",
120
+ "COVID-19": "A highly infectious respiratory illness caused by the SARS-CoV-2 virus.",
121
+ "monkeypox outbreak": "An outbreak of a viral disease similar to smallpox, which occurred in the LGBT community in New York City in 2022.",
122
+ "Hudson River": "A river in the northeastern United States, flowing from the Adirondack Mountains in New York into the Atlantic Ocean.",
123
+ "estuary": "A partly enclosed coastal body of brackish water with one or more rivers or streams flowing into it, and with a free connection to the open sea.",
124
+ "East River": "A tidal strait in New York City.",
125
+ "Five Boroughs": "Refers to the five counties that make up New York City: Bronx, Brooklyn, Manhattan, Queens, and Staten Island.",
126
+ "Staten Island": "The most suburban of the five boroughs, located southwest of Manhattan and connected to it by the free Staten Island Ferry.",
127
+ "Todt Hill": "The highest point on the eastern seaboard south of Maine, located on Staten Island.",
128
+ "Manhattan": "The geographically smallest and most densely populated borough of New York City, known for its skyscrapers, Central Park, and cultural, administrative, and financial centers.",
129
+ "Brooklyn": "The most populous borough of New York City, located on the western tip of Long Island and known for its cultural diversity, independent art scene, and distinctive neighborhoods.",
130
+ "Queens": "The largest borough of New York City, located on Long Island north and east of Brooklyn, and known for its ethnic diversity, commercial and residential prominence, and hosting of the annual U.S. Open tennis tournament.",
131
+ "The Bronx": "The northernmost borough of New York",
132
+ }
133
+
134
+ if "all_terms" not in st.session_state:
135
+ st.session_state["all_terms"] = DEFAULT_TERMS
136
+
137
+
138
+ def insert_terms(terms_to_definition):
139
+ for term, definition in terms_to_definition.items():
140
+ doc = Document(text=f"Term: {term}\nDefinition: {definition}")
141
+ st.session_state["llama_index"].insert(doc)
142
+
143
+
144
+ @st.cache_resource
145
+ def initialize_index(llm_name, model_temperature, api_key):
146
+ """Create the VectorStoreIndex object."""
147
+ # TODO update this thing in doc
148
+ Settings.llm = get_llm(llm_name, model_temperature, api_key)
149
+
150
+ # create a vector store index for each folder
151
+ try:
152
+ index = load_index_from_storage(
153
+ StorageContext.from_defaults(persist_dir="./initial_index")
154
+ )
155
+ except Exception as e:
156
+ docs = [
157
+ Document(text=key + " : " + value) for key, value in DEFAULT_TERMS.items()
158
+ ]
159
+ index = VectorStoreIndex.from_documents(docs)
160
+ index.storage_context.persist(persist_dir="./initial_index")
161
+ # TODO update this in docs
162
+ return index
163
+
164
+
165
+ DEFAULT_TERM_STR = (
166
+ "Make a list of terms and definitions that are defined in the context, "
167
+ "with one pair on each line. "
168
+ "If a term is missing it's definition, use your best judgment. "
169
+ "Write each line as as follows:\nTerm: <term> Definition: <definition>"
170
+ )
171
+
172
+ st.title("🦙 Llama Index Term Extractor 🦙")
173
+
174
+ setup_tab, terms_tab, upload_tab, query_tab = st.tabs(
175
+ ["Setup", "All Terms", "Upload/Extract Terms", "Query Terms"]
176
+ )
177
+
178
+ with setup_tab:
179
+ st.subheader("LLM Setup")
180
+ api_key = st.text_input("Enter your OpenAI API key here", type="password")
181
+ llm_name = st.selectbox(
182
+ "Choose an LLM", ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"]
183
+ )
184
+ model_temperature = st.slider(
185
+ "Model Temperature", min_value=0.0, max_value=1.0, value=0.0, step=0.1
186
+ )
187
+ term_extract_str = st.text_area(
188
+ "Enter your term extraction prompt here",
189
+ value=DEFAULT_TERM_STR,
190
+ )
191
+ with upload_tab:
192
+ st.subheader("Extract and Query Definitions")
193
+ if st.button("Initialize Index and Reset Terms"):
194
+ st.session_state["llama_index"] = initialize_index(
195
+ llm_name, model_temperature, api_key
196
+ )
197
+ st.session_state["all_terms"] = {}
198
+ if "llama_index" in st.session_state:
199
+ st.markdown(
200
+ "Either upload an image/screenshot of a document, or enter the text manually."
201
+ )
202
+ document_text = st.text_area("Or enter raw text")
203
+ # TODO remove uploaded_file in docs and update the text
204
+ if st.button("Extract Terms and Definitions") and document_text:
205
+ st.session_state["terms"] = {}
206
+ terms_docs = {}
207
+ with st.spinner("Extracting..."):
208
+ terms_docs.update(
209
+ extract_terms(
210
+ [Document(text=document_text)],
211
+ term_extract_str,
212
+ llm_name,
213
+ model_temperature,
214
+ api_key,
215
+ )
216
+ )
217
+ st.session_state["terms"].update(terms_docs)
218
+
219
+ if "terms" in st.session_state and st.session_state["terms"]:
220
+ st.markdown("Extracted terms")
221
+ st.json(st.session_state["terms"])
222
+
223
+ if st.button("Insert terms?"):
224
+ with st.spinner("Inserting terms"):
225
+ insert_terms(st.session_state["terms"])
226
+ st.session_state["all_terms"].update(st.session_state["terms"])
227
+ st.session_state["terms"] = {}
228
+ st.experimental_rerun()
229
+
230
+ with terms_tab:
231
+ with terms_tab:
232
+ st.subheader("Current Extracted Terms and Definitions")
233
+ st.json(st.session_state["all_terms"])
234
+
235
+ with query_tab:
236
+ st.subheader("Query for Terms/Definitions!")
237
+ st.markdown(
238
+ (
239
+ "The LLM will attempt to answer your query, and augment it's answers using the terms/definitions you've inserted. "
240
+ "If a term is not in the index, it will answer using it's internal knowledge."
241
+ )
242
+ )
243
+ if st.button("Initialize Index and Reset Terms", key="init_index_2"):
244
+ st.session_state["llama_index"] = initialize_index(
245
+ llm_name, model_temperature, api_key
246
+ )
247
+ st.session_state["all_terms"] = {}
248
+
249
+ if "llama_index" in st.session_state:
250
+ query_text = st.text_input("Ask about a term or definition:")
251
+ if query_text:
252
+ query_text = (
253
+ query_text
254
+ + "\nIf you can't find the answer, answer the query with the best of your knowledge."
255
+ )
256
+ # breakpoint()
257
+ with st.spinner("Generating answer..."):
258
+ response = (
259
+ st.session_state["llama_index"]
260
+ .as_query_engine(
261
+ similarity_top_k=5,
262
+ response_mode="compact",
263
+ text_qa_template=TEXT_QA_TEMPLATE,
264
+ refine_template=DEFAULT_REFINE_PROMPT,
265
+ )
266
+ .query(query_text)
267
+ )
268
+ st.markdown(str(response))