jsakshi commited on
Commit
ebdce26
·
verified ·
1 Parent(s): bc5c03f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -30,13 +30,18 @@ db = Chroma.from_documents(docs, embedding=DummyEmbeddings())
30
  retriever = db.as_retriever()
31
 
32
  # Step 4: Load a small open model instead of Mistral
33
- model_id = "google/flan-t5-base" # Or "microsoft/phi-2" (smaller)
 
 
 
 
34
  tokenizer = AutoTokenizer.from_pretrained(model_id)
35
- model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
36
 
37
- llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
38
  llm = HuggingFacePipeline(pipeline=llm_pipeline)
39
 
 
40
  # Step 5: RAG Chain
41
  qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
42
 
 
30
  retriever = db.as_retriever()
31
 
32
  # Step 4: Load a small open model instead of Mistral
33
+
34
+
35
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
36
+
37
+ model_id = "google/flan-t5-base"
38
  tokenizer = AutoTokenizer.from_pretrained(model_id)
39
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
40
 
41
+ llm_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
42
  llm = HuggingFacePipeline(pipeline=llm_pipeline)
43
 
44
+
45
  # Step 5: RAG Chain
46
  qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
47