Subha95 committed on
Commit
81f51c9
·
verified ·
1 Parent(s): 4356a00

Update chatbot_rag.py

Browse files
Files changed (1) hide show
  1. chatbot_rag.py +12 -19
chatbot_rag.py CHANGED
@@ -1,10 +1,9 @@
1
  from langchain_community.vectorstores import Chroma
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
  from langchain_community.llms import HuggingFacePipeline
4
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
5
  from langchain.chains import RetrievalQA
6
- import traceback # βœ… added
7
-
8
 
9
  def build_qa():
10
  """Builds and returns the RAG QA pipeline."""
@@ -23,46 +22,40 @@ def build_qa():
23
  )
24
  print("πŸ“‚ Docs in DB:", vectorstore._collection.count())
25
 
26
- # 3. LLM
27
  print("πŸ”Ή Loading LLM...")
28
- model_id = "sshleifer/tiny-gpt2"
29
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
30
- model = AutoModelForCausalLM.from_pretrained(
31
- model_id,
32
- device_map="auto",
33
- torch_dtype="auto"
34
- )
35
- print("βœ… LLM loaded.")
36
 
37
  pipe = pipeline(
38
- "question-answering",
39
  model=model,
40
  tokenizer=tokenizer,
41
  max_new_tokens=256,
42
- temperature=0.2,
43
  )
44
  llm = HuggingFacePipeline(pipeline=pipe)
45
 
46
- # 4. QA Chain
47
  print("πŸ”Ή Building RetrievalQA...")
48
  retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
49
  qa = RetrievalQA.from_chain_type(
50
  llm=llm,
51
  retriever=retriever,
52
- return_source_documents=False
 
53
  )
54
 
55
  print("βœ… QA pipeline ready.")
56
  return qa
57
 
58
-
59
- # Build at import time (so it's ready when app runs)
60
  try:
61
  qa_pipeline = build_qa()
62
  except Exception as e:
63
  qa_pipeline = None
64
  print("❌ Failed to build QA pipeline:", e)
65
- traceback.print_exc() # βœ… added: full error details
66
 
67
 
68
  def get_answer(query: str) -> str:
 
1
  from langchain_community.vectorstores import Chroma
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
  from langchain_community.llms import HuggingFacePipeline
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
  from langchain.chains import RetrievalQA
6
+ import traceback
 
7
 
8
  def build_qa():
9
  """Builds and returns the RAG QA pipeline."""
 
22
  )
23
  print("πŸ“‚ Docs in DB:", vectorstore._collection.count())
24
 
25
+ # 3. Load LLM (Flan-T5 small for lightweight QA)
26
  print("πŸ”Ή Loading LLM...")
27
+ model_id = "google/flan-t5-small"
28
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
29
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
 
 
 
 
 
30
 
31
  pipe = pipeline(
32
+ "text2text-generation",
33
  model=model,
34
  tokenizer=tokenizer,
35
  max_new_tokens=256,
 
36
  )
37
  llm = HuggingFacePipeline(pipeline=pipe)
38
 
39
+ # 4. QA Chain with retrieval
40
  print("πŸ”Ή Building RetrievalQA...")
41
  retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
42
  qa = RetrievalQA.from_chain_type(
43
  llm=llm,
44
  retriever=retriever,
45
+ return_source_documents=False,
46
+ chain_type="stuff" # simplest chain, passes context + question
47
  )
48
 
49
  print("βœ… QA pipeline ready.")
50
  return qa
51
 
52
+ # Build once
 
53
  try:
54
  qa_pipeline = build_qa()
55
  except Exception as e:
56
  qa_pipeline = None
57
  print("❌ Failed to build QA pipeline:", e)
58
+ traceback.print_exc()
59
 
60
 
61
  def get_answer(query: str) -> str: