Momal committed on
Commit
e866dc3
·
1 Parent(s): c88d08b

added app infra

Browse files
Files changed (3) hide show
  1. app.py +13 -2
  2. data/train.csv +0 -0
  3. mental_health_raqa.py +51 -4
app.py CHANGED
@@ -1,10 +1,21 @@
1
  import streamlit as st
 
2
 
3
- st.title('OpenAI Key Input')
 
4
 
5
  # Create a text input box for the OpenAI key
6
  openai_key = st.text_input('Enter your OpenAI Key', type='password')
7
 
8
  # Display the key when the user presses the 'Submit' button
9
  if st.button('Submit'):
10
- st.write(f'Your OpenAI Key is: {openai_key}')
 
 
 
 
 
 
 
 
 
 
import streamlit as st
from mental_health_raqa import mh_assistant


st.title('Mental Health Assistant :broken_heart:')

# Create a text input box for the OpenAI key.
openai_key = st.text_input('Enter your OpenAI Key', type='password')

# Render the query input unconditionally. In Streamlit, a button only
# returns True for the single rerun in which it was clicked, so widgets
# created inside `if st.button('Submit'):` disappear on the next rerun —
# the original nested Submit -> query -> Ask flow could never reach the
# Ask handler. A single always-visible input plus one button avoids that.
query = st.text_input('Enter your query', type='default')

if st.button('Ask'):
    if not openai_key:
        st.error('Please enter your OpenAI key', icon=':no_entry_sign:')
    elif not query:
        st.error('Please enter a query', icon=':no_entry_sign:')
    else:
        try:
            response = mh_assistant(openai_key, query)
            st.write(response)
        except Exception as e:
            # Surface failures (bad key, network, model errors) to the
            # user instead of crashing the app.
            st.error(f'An error occurred: {e}', icon=':no_entry_sign:')
data/train.csv DELETED
The diff for this file is too large to render. See raw diff
 
mental_health_raqa.py CHANGED
@@ -1,6 +1,53 @@
1
  import pandas as pd
 
 
 
 
 
 
 
 
 
2
 
3
- def mental_health_raqa(openai_key):
4
- # Load the data
5
- data = pd.read_csv('data/raqa.csv')
6
- return openai_key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
+ import os
3
+ from langchain.document_loaders.csv_loader import CSVLoader
4
+ from langchain.embeddings.openai import OpenAIEmbeddings
5
+ from langchain.embeddings import CacheBackedEmbeddings
6
+ from langchain.vectorstores import FAISS
7
+ from langchain.storage import LocalFileStore
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.callbacks import StdOutCallbackHandler
11
 
12
def create_index(df_path='data/mental_health_QnA2.csv'):
    """Build a FAISS retriever over the QnA CSV using cache-backed embeddings.

    Args:
        df_path: Path to the QnA CSV file. Defaults to the bundled dataset,
            so existing callers are unaffected.

    Returns:
        A retriever over the FAISS vector store of the loaded documents.
    """
    # Load one document per CSV row.
    loader = CSVLoader(file_path=df_path)
    data = loader.load()

    # Embeddings model (reads OPENAI_API_KEY from the environment).
    embeddings_model = OpenAIEmbeddings()

    # Wrap the model with a local-disk cache so repeated index builds do not
    # re-embed (and re-bill for) the same documents. The original code built
    # this wrapper but then passed the raw `embeddings_model` to FAISS,
    # silently bypassing the cache — fixed to use the cached embedder.
    store = LocalFileStore("./cache")
    cached_embedder = CacheBackedEmbeddings.from_bytes_store(
        embeddings_model, store, namespace=embeddings_model.model
    )
    vector_store = FAISS.from_documents(data, cached_embedder)

    return vector_store.as_retriever()
29
+
30
def setup(openai_key):
    """Configure OpenAI credentials and build the retriever + chat model.

    Args:
        openai_key: The user-supplied OpenAI API key.

    Returns:
        A ``(retriever, llm)`` tuple ready for use in a QA chain.
    """
    # SECURITY FIX: the original ignored `openai_key` and hard-coded a
    # leaked secret key here. That key must be considered compromised and
    # revoked; the user-supplied key is used instead.
    os.environ["OPENAI_API_KEY"] = openai_key
    retriever = create_index()
    llm = ChatOpenAI(model="gpt-4")
    return retriever, llm
36
+
37
+
38
def mh_assistant(openai_key, query):
    """Answer *query* via retrieval-augmented QA over the mental-health corpus.

    Args:
        openai_key: OpenAI API key forwarded to :func:`setup`.
        query: The user's question.

    Returns:
        The chain's result mapping, including the source documents.
    """
    # Build the retriever and chat model for this request.
    retriever, llm = setup(openai_key)

    # Stream chain progress to stdout for debugging visibility.
    stdout_handler = StdOutCallbackHandler()

    qa_with_sources_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        callbacks=[stdout_handler],
        return_source_documents=True,
    )

    # Run the chain on the question and hand back the full result.
    return qa_with_sources_chain({"query": query})