Spaces:
Runtime error
Runtime error
feedback
Browse files- app.py +25 -10
- rag/rag.py +1 -1
app.py
CHANGED
|
@@ -15,9 +15,11 @@ st.set_page_config(
|
|
| 15 |
menu_items=None,
|
| 16 |
)
|
| 17 |
|
|
|
|
|
|
|
| 18 |
WANDB_PROJECT = "paper_reader"
|
| 19 |
|
| 20 |
-
weave.init(f"{WANDB_PROJECT}")
|
| 21 |
|
| 22 |
st.title("Chat with the Llama 3 paper 💬🦙")
|
| 23 |
|
|
@@ -33,16 +35,29 @@ with st.spinner('Loading the RAG pipeline...'):
|
|
| 33 |
if "rag_pipeline" not in st.session_state.keys():
|
| 34 |
st.session_state.rag_pipeline = load_rag_pipeline()
|
| 35 |
|
| 36 |
-
rag_pipeline = st.session_state["rag_pipeline"]
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def generate_response(query):
|
| 40 |
-
response = rag_pipeline.predict(query)
|
| 41 |
-
st.write_stream(response.response_gen)
|
| 42 |
-
|
| 43 |
|
| 44 |
with st.form("my_form"):
|
| 45 |
query = st.text_area("Ask your question about the Llama 3 paper here:")
|
| 46 |
submitted = st.form_submit_button("Submit")
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
menu_items=None,
|
| 16 |
)
|
| 17 |
|
| 18 |
+
st.session_state['session_id'] = '123abc'
|
| 19 |
+
|
| 20 |
WANDB_PROJECT = "paper_reader"
|
| 21 |
|
| 22 |
+
weave_client = weave.init(f"{WANDB_PROJECT}")
|
| 23 |
|
| 24 |
st.title("Chat with the Llama 3 paper 💬🦙")
|
| 25 |
|
|
|
|
| 35 |
if "rag_pipeline" not in st.session_state.keys():
|
| 36 |
st.session_state.rag_pipeline = load_rag_pipeline()
|
| 37 |
|
| 38 |
+
rag_pipeline = st.session_state["rag_pipeline"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
with st.form("my_form"):
|
| 41 |
query = st.text_area("Ask your question about the Llama 3 paper here:")
|
| 42 |
submitted = st.form_submit_button("Submit")
|
| 43 |
+
|
| 44 |
+
if submitted:
|
| 45 |
+
with st.spinner('Generating answer...'):
|
| 46 |
+
output = rag_pipeline.predict(query)
|
| 47 |
+
st.session_state["last_output"] = output
|
| 48 |
+
text = ""
|
| 49 |
+
for t in output["response"].response_gen:
|
| 50 |
+
text += t
|
| 51 |
+
st.session_state["last_text"] = text
|
| 52 |
+
|
| 53 |
+
st.write_stream(output["response"].response_gen)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
if "last_output" in st.session_state:
|
| 57 |
+
output = st.session_state["last_output"]
|
| 58 |
+
text = st.session_state["last_text"]
|
| 59 |
+
st.write(text)
|
| 60 |
+
|
| 61 |
+
# use the weave client to retrieve the call and attach feedback
|
| 62 |
+
st.button(":thumbsup:", on_click=lambda: weave_client.call(output['call_id']).feedback.add_reaction("👍"), key='up')
|
| 63 |
+
st.button(":thumbsdown:", on_click=lambda: weave_client.call(output['call_id']).feedback.add_reaction("👎"), key='down')
|
rag/rag.py
CHANGED
|
@@ -150,7 +150,7 @@ class SimpleRAGPipeline(weave.Model):
|
|
| 150 |
@weave.op()
|
| 151 |
def predict(self, question: str):
|
| 152 |
response = self.query_engine.query(question)
|
| 153 |
-
return response
|
| 154 |
|
| 155 |
|
| 156 |
if __name__ == "__main__":
|
|
|
|
| 150 |
@weave.op()
|
| 151 |
def predict(self, question: str):
|
| 152 |
response = self.query_engine.query(question)
|
| 153 |
+
return {"response": response, 'call_id': weave.get_current_call().id}
|
| 154 |
|
| 155 |
|
| 156 |
if __name__ == "__main__":
|