Fix LLM response generation in the Streamlit RAG app
Browse files- src/streamlit_app.py +6 -5
src/streamlit_app.py
CHANGED
|
@@ -40,13 +40,14 @@ def load_retriever():
|
|
| 40 |
# Load the chat model via a HuggingFace pipeline (cached across Streamlit reruns).
@st.cache_resource
def load_llm():
    """Build and cache the LLM used by the RAG chain.

    Returns:
        HuggingFacePipeline wrapping a ``text-generation`` pipeline for the
        Llama-2 chat model. ``st.cache_resource`` ensures the model is
        loaded only once per server process.
    """
    # BUG FIX: the original created the pipeline from "google/flan-t5-small"
    # (a seq2seq checkpoint, wrong for the text-generation task) and then
    # loaded the Llama-2 tokenizer/model without ever using them.
    model_name = "meta-llama/Llama-2-7b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # device_map="auto" places weights on GPU when available; torch_dtype="auto"
    # picks the checkpoint's native dtype.
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto"
    )
    pipe = pipeline(
        "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256
    )
    return HuggingFacePipeline(pipeline=pipe)
| 51 |
|
| 52 |
# Setup RAG Chain
|
|
def setup_qa():
    """Assemble and return the retrieval-augmented QA chain.

    Returns:
        A retrieval chain combining the cached retriever and LLM.
    """
    retriever = load_retriever()
    llm = load_llm()
    # Stuff retrieved documents into the prompt, then wrap with retrieval.
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    chain = create_retrieval_chain(retriever, question_answer_chain)
    # BUG FIX: the original ended with a bare `return`, returning None and
    # discarding the chain it had just built.
    return chain
|
| 63 |
|
| 64 |
|
| 65 |
# Streamlit App UI
|
|
|
|
| 40 |
# Load the chat model via a HuggingFace pipeline (cached across Streamlit reruns).
@st.cache_resource
def load_llm():
    """Build and cache the LLM used by the RAG chain.

    Returns:
        HuggingFacePipeline wrapping a ``text-generation`` pipeline for the
        Llama-2 chat model. ``st.cache_resource`` ensures the model is
        loaded only once per server process.
    """
    # Load tokenizer and model explicitly so the pipeline reuses them.
    # (Removed the commented-out flan-t5 pipeline line: dead code.)
    model_name = "meta-llama/Llama-2-7b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # device_map="auto" places weights on GPU when available; torch_dtype="auto"
    # picks the checkpoint's native dtype.
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto"
    )
    pipe = pipeline(
        "text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256
    )
    return HuggingFacePipeline(pipeline=pipe)
|
| 52 |
|
| 53 |
# Setup RAG Chain
|
|
|
|
| 57 |
def setup_qa():
    """Assemble and return the RetrievalQA chain.

    Returns:
        RetrievalQA chain ("stuff" document strategy) combining the cached
        retriever and LLM.
    """
    retriever = load_retriever()
    llm = load_llm()
    # BUG FIX: the original still built `question_answer_chain` with the
    # custom `prompt` via create_stuff_documents_chain but never used it,
    # so the prompt was silently dropped. RetrievalQA accepts the prompt
    # through chain_type_kwargs for the "stuff" chain type.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt},
    )
    return qa_chain
|
| 64 |
|
| 65 |
|
| 66 |
# Streamlit App UI
|