Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -66,23 +66,28 @@ def create_db(splits):
|
|
| 66 |
vectordb = FAISS.from_documents(splits, embeddings)
|
| 67 |
return vectordb
|
| 68 |
|
| 69 |
-
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
|
| 70 |
-
"""Initialize the LLM chain"""
|
|
|
|
| 71 |
llm = HuggingFaceEndpoint(
|
| 72 |
repo_id=llm_model,
|
| 73 |
huggingfacehub_api_token=api_token,
|
| 74 |
temperature=temperature,
|
| 75 |
-
|
| 76 |
-
top_k
|
| 77 |
)
|
| 78 |
|
|
|
|
| 79 |
memory = ConversationBufferMemory(
|
| 80 |
memory_key="chat_history",
|
| 81 |
output_key='answer',
|
| 82 |
return_messages=True
|
| 83 |
)
|
| 84 |
|
|
|
|
| 85 |
retriever = vector_db.as_retriever()
|
|
|
|
|
|
|
| 86 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
| 87 |
llm,
|
| 88 |
retriever=retriever,
|
|
@@ -93,6 +98,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
|
|
| 93 |
)
|
| 94 |
return qa_chain
|
| 95 |
|
|
|
|
| 96 |
def format_chat_history(message, chat_history):
|
| 97 |
"""Format chat history for the LLM"""
|
| 98 |
formatted_chat_history = []
|
|
|
|
| 66 |
vectordb = FAISS.from_documents(splits, embeddings)
|
| 67 |
return vectordb
|
| 68 |
|
| 69 |
+
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, api_token):
|
| 70 |
+
"""Initialize the LLM chain with a HuggingFace model"""
|
| 71 |
+
# Use valid Hugging Face parameters. `max_length` might be the correct field instead of `max_new_tokens`
|
| 72 |
llm = HuggingFaceEndpoint(
|
| 73 |
repo_id=llm_model,
|
| 74 |
huggingfacehub_api_token=api_token,
|
| 75 |
temperature=temperature,
|
| 76 |
+
max_length=max_tokens, # Adjusted from max_new_tokens to max_length
|
| 77 |
+
# Remove top_k as it may not be valid or handled differently
|
| 78 |
)
|
| 79 |
|
| 80 |
+
# Set up memory for conversation
|
| 81 |
memory = ConversationBufferMemory(
|
| 82 |
memory_key="chat_history",
|
| 83 |
output_key='answer',
|
| 84 |
return_messages=True
|
| 85 |
)
|
| 86 |
|
| 87 |
+
# Ensure vector_db is used as a retriever
|
| 88 |
retriever = vector_db.as_retriever()
|
| 89 |
+
|
| 90 |
+
# Initialize ConversationalRetrievalChain using LLM and the retriever
|
| 91 |
qa_chain = ConversationalRetrievalChain.from_llm(
|
| 92 |
llm,
|
| 93 |
retriever=retriever,
|
|
|
|
| 98 |
)
|
| 99 |
return qa_chain
|
| 100 |
|
| 101 |
+
|
| 102 |
def format_chat_history(message, chat_history):
|
| 103 |
"""Format chat history for the LLM"""
|
| 104 |
formatted_chat_history = []
|