udituen committed on
Commit
660ad64
·
1 Parent(s): 987cbb7

fixing llm response

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +6 -5
src/streamlit_app.py CHANGED
@@ -40,13 +40,14 @@ def load_retriever():
40
  # Load a lightweight model via HuggingFace pipeline
41
  @st.cache_resource
42
  def load_llm():
43
- pipe = pipeline("text-generation", model="google/flan-t5-small", max_new_tokens=256)
44
  # load the tokenizer and model on cpu/gpu
45
 
46
  model_name = "meta-llama/Llama-2-7b-chat-hf"
47
  tokenizer = AutoTokenizer.from_pretrained(model_name)
48
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
49
- # pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)
 
50
  return HuggingFacePipeline(pipeline=pipe)
51
 
52
  # Setup RAG Chain
@@ -56,10 +57,10 @@ def setup_qa():
56
  retriever = load_retriever()
57
  llm = load_llm()
58
  question_answer_chain = create_stuff_documents_chain(llm,prompt)
59
- chain = create_retrieval_chain(retriever, question_answer_chain)
60
 
61
- # qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
62
- return chain
63
 
64
 
65
  # Streamlit App UI
 
40
  # Load a lightweight model via HuggingFace pipeline
41
  @st.cache_resource
42
  def load_llm():
43
+ # pipe = pipeline("text-generation", model="google/flan-t5-small", max_new_tokens=256)
44
  # load the tokenizer and model on cpu/gpu
45
 
46
  model_name = "meta-llama/Llama-2-7b-chat-hf"
47
  tokenizer = AutoTokenizer.from_pretrained(model_name)
48
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
49
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)
50
+
51
  return HuggingFacePipeline(pipeline=pipe)
52
 
53
  # Setup RAG Chain
 
57
  retriever = load_retriever()
58
  llm = load_llm()
59
  question_answer_chain = create_stuff_documents_chain(llm,prompt)
60
+ # chain = create_retrieval_chain(retriever, question_answer_chain)
61
 
62
+ qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
63
+ return qa_chain
64
 
65
 
66
  # Streamlit App UI