Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,9 +14,10 @@ from langchain_core.output_parsers import StrOutputParser
|
|
| 14 |
from langchain_core.runnables import RunnableLambda
|
| 15 |
from datetime import date
|
| 16 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 17 |
-
|
| 18 |
import threading
|
| 19 |
import time
|
|
|
|
|
|
|
| 20 |
# Environment variables
|
| 21 |
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
|
| 22 |
os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
|
|
@@ -65,7 +66,7 @@ def retrieve_normal_context(retriever, question):
|
|
| 65 |
# Your OLMOLLM class implementation here (adapted for the Hugging Face model)
|
| 66 |
|
| 67 |
@st.cache_resource
|
| 68 |
-
def get_chain(temperature):
|
| 69 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
|
| 70 |
|
| 71 |
docstore_path = 'ohw_proj_chorma_db.pcl'
|
|
@@ -77,20 +78,26 @@ def get_chain(temperature):
|
|
| 77 |
child_splitter = RecursiveCharacterTextSplitter(chunk_size=300,
|
| 78 |
chunk_overlap=50)
|
| 79 |
retriever = load_retriever(docstore_path,chroma_path,embeddings,child_splitter,parent_splitter)
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
today = date.today()
|
|
@@ -147,8 +154,20 @@ def generate_response(chain, query, context):
|
|
| 147 |
# Sidebar
|
| 148 |
with st.sidebar:
|
| 149 |
st.title("OHW Assistant")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
temperature = st.slider("Temperature: ", 0.0, 1.0, 0.5, 0.1)
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
st.button('Clear Chat History', on_click=clear_chat_history)
|
| 153 |
|
| 154 |
# Main app
|
|
|
|
| 14 |
from langchain_core.runnables import RunnableLambda
|
| 15 |
from datetime import date
|
| 16 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
|
|
| 17 |
import threading
|
| 18 |
import time
|
| 19 |
+
llm_list = ['Mistral-7B-Instruct-v0.2','Mixtral-8x7B-Instruct-v0.1','LLAMA3']
|
| 20 |
+
blablador_base = "https://helmholtz-blablador.fz-juelich.de:8000/v1"
|
| 21 |
# Environment variables
|
| 22 |
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
|
| 23 |
os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
|
|
|
|
| 66 |
# Your OLMOLLM class implementation here (adapted for the Hugging Face model)
|
| 67 |
|
| 68 |
@st.cache_resource
|
| 69 |
+
def get_chain(temperature,selected_model):
|
| 70 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
|
| 71 |
|
| 72 |
docstore_path = 'ohw_proj_chorma_db.pcl'
|
|
|
|
| 78 |
child_splitter = RecursiveCharacterTextSplitter(chunk_size=300,
|
| 79 |
chunk_overlap=50)
|
| 80 |
retriever = load_retriever(docstore_path,chroma_path,embeddings,child_splitter,parent_splitter)
|
| 81 |
+
llm_api = 'glpat-AMzMevbqaVjp4HbLcVum'
|
| 82 |
+
llm = ChatOpenAI(model_name=selected_model,
|
| 83 |
+
temperature=temperature,
|
| 84 |
+
openai_api_key=llm_api,
|
| 85 |
+
openai_api_base=blablador_base,
|
| 86 |
+
streaming=True)
|
| 87 |
+
# model, tokenizer = load_model()
|
| 88 |
+
|
| 89 |
+
# pipe = pipeline(
|
| 90 |
+
# "text-generation",
|
| 91 |
+
# model=model,
|
| 92 |
+
# tokenizer=tokenizer,
|
| 93 |
+
# max_length=1800,
|
| 94 |
+
# max_new_tokens = 200,
|
| 95 |
+
# temperature=temperature,
|
| 96 |
+
# top_p=0.95,
|
| 97 |
+
# repetition_penalty=1.15
|
| 98 |
+
# )
|
| 99 |
+
|
| 100 |
+
# llm = HuggingFacePipeline(pipeline=pipe)
|
| 101 |
|
| 102 |
|
| 103 |
today = date.today()
|
|
|
|
| 154 |
# Sidebar
|
| 155 |
with st.sidebar:
|
| 156 |
st.title("OHW Assistant")
|
| 157 |
+
selected_model = st.sidebar.selectbox('Choose a LLM model',
|
| 158 |
+
llm_list,
|
| 159 |
+
key='selected_model',
|
| 160 |
+
index = None)
|
| 161 |
+
|
| 162 |
temperature = st.slider("Temperature: ", 0.0, 1.0, 0.5, 0.1)
|
| 163 |
+
if selected_model in ['Mistral-7B-Instruct-v0.2', 'Mixtral-8x7B-Instruct-v0.1','LLAMA3']:
|
| 164 |
+
if selected_model == 'Mistral-7B-Instruct-v0.2':
|
| 165 |
+
selected_model = 'alias-fast'
|
| 166 |
+
elif selected_model == 'Mixtral-8x7B-Instruct-v0.1':
|
| 167 |
+
selected_model = 'alias-large'
|
| 168 |
+
elif selected_model == 'LLAMA3':
|
| 169 |
+
selected_model = 'alias-experimental'
|
| 170 |
+
chain = get_chain(temperature,selected_model)
|
| 171 |
st.button('Clear Chat History', on_click=clear_chat_history)
|
| 172 |
|
| 173 |
# Main app
|