udituen committed on
Commit
bacd419
·
1 Parent(s): 7ff4c08

change llm

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +9 -11
src/streamlit_app.py CHANGED
@@ -9,15 +9,11 @@ from langchain.chains import create_retrieval_chain
9
  from langchain.chains.combine_documents import create_stuff_documents_chain
10
  from langchain_community.llms import Ollama
11
  import os
 
 
12
 
13
 
14
  # ----------------------
15
- # HF_CACHE_PATH = "./app_cache"
16
- # # os.makedirs(HF_CACHE_PATH, exist_ok=True)
17
- # os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_PATH
18
- # os.environ["HF_HOME"] = HF_CACHE_PATH
19
-
20
-
21
  system_prompt = (
22
  "You are an agriultural research assistant."
23
  "Use the given context to answer the question."
@@ -35,10 +31,6 @@ prompt = ChatPromptTemplate.from_messages(
35
  # Initialize embeddings & documents
36
  @st.cache_resource
37
  def load_retriever():
38
- # Load documents
39
- # with open("data/docs.txt", "r") as f:
40
- # docs = f.read().split("\n")
41
- # Later load
42
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
43
  db = FAISS.load_local("./vectorstore/agriquery_faiss_index", embeddings, allow_dangerous_deserialization=True)
44
  retriever = db.as_retriever()
@@ -47,7 +39,13 @@ def load_retriever():
47
  # Load a lightweight model via HuggingFace pipeline
48
  @st.cache_resource
49
  def load_llm():
50
- pipe = pipeline("text-generation", model="google/flan-t5-small", max_new_tokens=256)
 
 
 
 
 
 
51
  return HuggingFacePipeline(pipeline=pipe)
52
 
53
  # Setup RAG Chain
 
9
  from langchain.chains.combine_documents import create_stuff_documents_chain
10
  from langchain_community.llms import Ollama
11
  import os
12
+ import torch
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
14
 
15
 
16
  # ----------------------
 
 
 
 
 
 
17
  system_prompt = (
18
  "You are an agriultural research assistant."
19
  "Use the given context to answer the question."
 
31
  # Initialize embeddings & documents
32
  @st.cache_resource
33
  def load_retriever():
 
 
 
 
34
  embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
35
  db = FAISS.load_local("./vectorstore/agriquery_faiss_index", embeddings, allow_dangerous_deserialization=True)
36
  retriever = db.as_retriever()
 
39
  # Load a lightweight model via HuggingFace pipeline
40
  @st.cache_resource
41
  def load_llm():
42
+ # pipe = pipeline("text-generation", model="google/flan-t5-small", max_new_tokens=256)
43
+
44
+ # load the tokenizer and model on cpu/gpu
45
+ model_name = "meta-llama/Llama-2-7b-chat-hf"
46
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
47
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
48
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)
49
  return HuggingFacePipeline(pipeline=pipe)
50
 
51
  # Setup RAG Chain