Subha95 committed on
Commit
71b6f6e
·
verified ·
1 Parent(s): 3569bcd

Update chatbot_rag.py

Browse files
Files changed (1) hide show
  1. chatbot_rag.py +57 -46
chatbot_rag.py CHANGED
@@ -1,18 +1,22 @@
1
  from langchain_community.vectorstores import Chroma
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
  from langchain_community.llms import HuggingFacePipeline
4
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
- from langchain.chains import RetrievalQA
6
  from langchain.prompts import PromptTemplate
 
 
7
  import traceback
8
 
 
9
  def build_qa():
10
- """Builds and returns the RAG QA pipeline."""
11
  print("πŸš€ Starting QA pipeline...")
12
 
13
  # 1. Embeddings
14
  print("πŸ”Ή Loading embeddings...")
15
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
 
16
 
17
  # 2. Load vector DB
18
  print("πŸ”Ή Loading Chroma DB...")
@@ -23,50 +27,66 @@ def build_qa():
23
  )
24
  print("πŸ“‚ Docs in DB:", vectorstore._collection.count())
25
 
26
- # 3. Load LLM (Flan-T5 small for lightweight QA)
27
  print("πŸ”Ή Loading LLM...")
28
  model_id = "microsoft/Phi-3-mini-4k-instruct"
29
  tokenizer = AutoTokenizer.from_pretrained(model_id)
30
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
31
 
32
  pipe = pipeline(
33
- "text-generation",
34
- model=model,
35
- tokenizer=tokenizer,
36
- max_new_tokens=300,
37
- do_sample=True, # Set to True to enable sampling and use temperature
38
- temperature=0.2 # This is the temperature parameter
39
  )
40
  llm = HuggingFacePipeline(pipeline=pipe)
41
 
42
- # 4. QA Chain with retrieval
43
- print("πŸ”Ή Building RetrievalQA...")
44
- retriever = vectorstore.as_retriever()
45
-
46
-
47
- template = """
48
- Use the following context to answer the question at the end.
49
- If you don't know the answer, just say "I don't know" β€” do not make up an answer.
50
-
51
- Context:
52
- {context}
53
-
54
- Question: {question}
55
- Answer (one short sentence):
56
- """
57
- qa_prompt = PromptTemplate(template=template, input_variables=["context", "question"])
58
-
59
- qa = RetrievalQA.from_chain_type(
60
- llm=llm,
61
- retriever=retriever,
62
- chain_type="stuff",
63
- chain_type_kwargs={"prompt": qa_prompt},
64
- return_source_documents=False,
65
  )
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  print("βœ… QA pipeline ready.")
69
- return qa
 
70
 
71
  # Build once
72
  try:
@@ -81,17 +101,8 @@ def get_answer(query: str) -> str:
81
  """Takes user query and returns chatbot response."""
82
  if qa_pipeline is None:
83
  return "⚠️ QA pipeline not initialized."
84
-
85
- try:
86
- retriever = qa_pipeline.retriever
87
- docs = retriever.get_relevant_documents(query)
88
- print("πŸ“‚ Retrieved docs:", len(docs))
89
- if not docs:
90
- return "⚠️ No documents found in the DB. Check your `db/` folder."
91
- except Exception as e:
92
- return f"❌ Retriever error: {e}"
93
 
94
  try:
95
- return qa_pipeline.run(query)
96
  except Exception as e:
97
  return f"❌ QA run failed: {e}"
 
1
  from langchain_community.vectorstores import Chroma
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
  from langchain_community.llms import HuggingFacePipeline
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
5
  from langchain.prompts import PromptTemplate
6
+ from langchain_core.runnables import RunnablePassthrough
7
+ from langchain_core.output_parsers import StrOutputParser
8
  import traceback
9
 
10
+
11
  def build_qa():
12
+ """Builds and returns the RAG QA pipeline (rag_chain style)."""
13
  print("πŸš€ Starting QA pipeline...")
14
 
15
  # 1. Embeddings
16
  print("πŸ”Ή Loading embeddings...")
17
+ embeddings = HuggingFaceEmbeddings(
18
+ model_name="sentence-transformers/all-MiniLM-L6-v2"
19
+ )
20
 
21
  # 2. Load vector DB
22
  print("πŸ”Ή Loading Chroma DB...")
 
27
  )
28
  print("πŸ“‚ Docs in DB:", vectorstore._collection.count())
29
 
30
+ # 3. Load LLM (Phi-3 mini)
31
  print("πŸ”Ή Loading LLM...")
32
  model_id = "microsoft/Phi-3-mini-4k-instruct"
33
  tokenizer = AutoTokenizer.from_pretrained(model_id)
34
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
35
 
36
  pipe = pipeline(
37
+ "text-generation",
38
+ model=model,
39
+ tokenizer=tokenizer,
40
+ max_new_tokens=300,
41
+ do_sample=True,
42
+ temperature=0.2,
43
  )
44
  llm = HuggingFacePipeline(pipeline=pipe)
45
 
46
+ # 4. Retriever
47
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
48
+
49
+ # 5. Prompt
50
+ prompt = PromptTemplate(
51
+ input_variables=["context", "question"],
52
+ template="""
53
+ Use the following context to answer the question at the end.
54
+ If you don't know the answer, just say "I don't know" β€” do not make up an answer.
55
+
56
+ Context:
57
+ {context}
58
+
59
+ Question: {question}
60
+ Answer (one short sentence):
61
+ """,
 
 
 
 
 
 
 
62
  )
63
 
64
+ # 6. Helper functions
65
def format_docs(docs):
    """Concatenate the page contents of retrieved documents into one context string."""
    contents = [d.page_content for d in docs]
    return "\n".join(contents)
67
+
68
def hf_to_str(x):
    """Convert Hugging Face pipeline output to a plain string.

    A raw transformers text-generation pipeline returns a list of dicts
    like [{"generated_text": "..."}]; anything else (e.g. a string already
    produced by HuggingFacePipeline) is passed through via str().

    Fixes over the original:
    - guard against an empty list (the original raised IndexError on x[0])
    - require x[0] to be a dict before the membership test (the original
      did a substring search when x[0] was a string)
    """
    if isinstance(x, list) and x and isinstance(x[0], dict) and "generated_text" in x[0]:
        return x[0]["generated_text"]
    return str(x)
73
+
74
+ # 7. RAG chain
75
+ rag_chain = (
76
+ {
77
+ "context": retriever | format_docs,
78
+ "question": RunnablePassthrough(),
79
+ }
80
+ | prompt
81
+ | (lambda x: str(x)) # convert PromptTemplate value to str
82
+ | llm
83
+ | (lambda x: hf_to_str(x)) # clean HF output
84
+ | StrOutputParser()
85
+ )
86
 
87
  print("βœ… QA pipeline ready.")
88
+ return rag_chain
89
+
90
 
91
  # Build once
92
  try:
 
101
  """Takes user query and returns chatbot response."""
102
  if qa_pipeline is None:
103
  return "⚠️ QA pipeline not initialized."
 
 
 
 
 
 
 
 
 
104
 
105
  try:
106
+ return qa_pipeline.invoke(query)
107
  except Exception as e:
108
  return f"❌ QA run failed: {e}"