Subha95 committed on
Commit
abd8f5a
·
verified ·
1 Parent(s): 9087b24

Update chatbot_rag.py

Browse files
Files changed (1) hide show
  1. chatbot_rag.py +23 -18
chatbot_rag.py CHANGED
@@ -10,34 +10,39 @@ from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
10
  from langchain_chroma import Chroma
11
 
12
  def build_qa():
13
- """Builds and returns the RAG QA pipeline."""
14
-
15
- # 1. Load embeddings + DB
16
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
17
  vectorstore = Chroma(
18
  persist_directory="db",
19
  collection_name="rag-docs",
20
  embedding_function=embeddings,
21
  )
22
 
23
- # 2. LLM (instruction-tuned preferred)
24
  model_id = "microsoft/phi-3-mini-4k-instruct"
25
- tokenizer = AutoTokenizer.from_pretrained(model_id)
26
- model = AutoModelForCausalLM.from_pretrained(model_id)
27
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
28
- llm = HuggingFacePipeline(pipeline=pipe)
29
 
30
- # 3. QA Chain
31
- retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
32
- qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, return_source_documents=False)
33
-
34
- return qa
 
35
 
 
 
 
 
 
 
 
36
 
37
- # Build once (so Hugging Face loads at startup)
38
- qa_pipeline = build_qa()
39
 
 
 
 
 
 
 
40
 
41
- def get_answer(query: str) -> str:
42
- """Takes user query and returns chatbot response."""
43
- return qa_pipeline.run(query)
 
10
  from langchain_chroma import Chroma
11
 
12
  def build_qa():
 
 
 
13
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
14
+
15
  vectorstore = Chroma(
16
  persist_directory="db",
17
  collection_name="rag-docs",
18
  embedding_function=embeddings,
19
  )
20
 
21
+ # 🔹 Use Phi-3 Mini (smaller, faster)
22
  model_id = "microsoft/phi-3-mini-4k-instruct"
 
 
 
 
23
 
24
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
25
+ model = AutoModelForCausalLM.from_pretrained(
26
+ model_id,
27
+ device_map="auto", # ✅ auto place on GPU if available
28
+ torch_dtype="auto" # ✅ better memory handling
29
+ )
30
 
31
+ pipe = pipeline(
32
+ "text-generation",
33
+ model=model,
34
+ tokenizer=tokenizer,
35
+ max_new_tokens=256, # ✅ smaller output (faster)
36
+ temperature=0.2, # ✅ more focused answers
37
+ )
38
 
39
+ llm = HuggingFacePipeline(pipeline=pipe)
 
40
 
41
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
42
+ qa = RetrievalQA.from_chain_type(
43
+ llm=llm,
44
+ retriever=retriever,
45
+ return_source_documents=False
46
+ )
47
 
48
+ return qa