Subha95 committed on
Commit
31ce18a
·
verified ·
1 Parent(s): abd8f5a

Update chatbot_rag.py

Browse files
Files changed (1) hide show
  1. chatbot_rag.py +27 -11
chatbot_rag.py CHANGED
@@ -1,43 +1,44 @@
1
 
2
-
3
- # rag_pipeline.py
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
- from langchain.llms import HuggingFacePipeline
8
  from langchain.chains import RetrievalQA
9
- from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
10
- from langchain_chroma import Chroma
11
 
12
  def build_qa():
 
 
 
13
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
14
 
 
15
  vectorstore = Chroma(
16
  persist_directory="db",
17
  collection_name="rag-docs",
18
  embedding_function=embeddings,
19
  )
20
 
21
- # 🔹 Use Phi-3 Mini (smaller, faster)
22
  model_id = "microsoft/phi-3-mini-4k-instruct"
23
-
24
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
25
  model = AutoModelForCausalLM.from_pretrained(
26
  model_id,
27
- device_map="auto", # ✅ auto place on GPU if available
28
- torch_dtype="auto" # ✅ better memory handling
29
  )
30
 
31
  pipe = pipeline(
32
  "text-generation",
33
  model=model,
34
  tokenizer=tokenizer,
35
- max_new_tokens=256, # ✅ smaller output (faster)
36
- temperature=0.2, # ✅ more focused answers
37
  )
38
 
39
  llm = HuggingFacePipeline(pipeline=pipe)
40
 
 
41
  retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
42
  qa = RetrievalQA.from_chain_type(
43
  llm=llm,
@@ -46,3 +47,18 @@ def build_qa():
46
  )
47
 
48
  return qa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
 
 
2
  from langchain_community.vectorstores import Chroma
3
  from langchain_community.embeddings import HuggingFaceEmbeddings
4
+ from langchain_community.llms import HuggingFacePipeline
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
6
  from langchain.chains import RetrievalQA
7
+
 
8
 
9
  def build_qa():
10
+ """Builds and returns the RAG QA pipeline."""
11
+
12
+ # 1. Embeddings
13
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
14
 
15
+ # 2. Load vector DB (must already exist in ./db)
16
  vectorstore = Chroma(
17
  persist_directory="db",
18
  collection_name="rag-docs",
19
  embedding_function=embeddings,
20
  )
21
 
22
+ # 3. LLM (lighter model = faster inference)
23
  model_id = "microsoft/phi-3-mini-4k-instruct"
 
24
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
25
  model = AutoModelForCausalLM.from_pretrained(
26
  model_id,
27
+ device_map="auto",
28
+ torch_dtype="auto"
29
  )
30
 
31
  pipe = pipeline(
32
  "text-generation",
33
  model=model,
34
  tokenizer=tokenizer,
35
+ max_new_tokens=256,
36
+ temperature=0.2,
37
  )
38
 
39
  llm = HuggingFacePipeline(pipeline=pipe)
40
 
41
+ # 4. RAG chain
42
  retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
43
  qa = RetrievalQA.from_chain_type(
44
  llm=llm,
 
47
  )
48
 
49
  return qa
50
+
51
+
52
# Build the QA pipeline eagerly at import time so it is ready when the app runs.
# NOTE(review): build_qa() downloads/loads a multi-GB causal LM and opens the
# Chroma store — importing this module is therefore slow and resource-heavy.
try:
    qa_pipeline = build_qa()
except Exception as e:
    # Broad catch is deliberate best-effort: any failure (missing ./db directory,
    # model download error, OOM) leaves the app importable; get_answer() reports
    # the degraded state to the user instead of crashing at import.
    qa_pipeline = None
    print("❌ Failed to build QA pipeline:", e)
58
+
59
+
60
def get_answer(query: str) -> str:
    """Return the chatbot's answer for *query*.

    Falls back to a warning string when the module-level ``qa_pipeline``
    failed to build at import time (it is ``None`` in that case).
    """
    if qa_pipeline is None:
        return "⚠️ QA pipeline not initialized."
    answer = qa_pipeline.run(query)
    return answer