Spaces:
Sleeping
Sleeping
Update chatbot_rag.py
Browse files- chatbot_rag.py +12 -19
chatbot_rag.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
from langchain_community.vectorstores import Chroma
|
| 2 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 3 |
from langchain_community.llms import HuggingFacePipeline
|
| 4 |
-
from transformers import AutoTokenizer,
|
| 5 |
from langchain.chains import RetrievalQA
|
| 6 |
-
import traceback
|
| 7 |
-
|
| 8 |
|
| 9 |
def build_qa():
|
| 10 |
"""Builds and returns the RAG QA pipeline."""
|
|
@@ -23,46 +22,40 @@ def build_qa():
|
|
| 23 |
)
|
| 24 |
print("π Docs in DB:", vectorstore._collection.count())
|
| 25 |
|
| 26 |
-
# 3. LLM
|
| 27 |
print("πΉ Loading LLM...")
|
| 28 |
-
model_id = "
|
| 29 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id
|
| 30 |
-
model =
|
| 31 |
-
model_id,
|
| 32 |
-
device_map="auto",
|
| 33 |
-
torch_dtype="auto"
|
| 34 |
-
)
|
| 35 |
-
print("β
LLM loaded.")
|
| 36 |
|
| 37 |
pipe = pipeline(
|
| 38 |
-
"
|
| 39 |
model=model,
|
| 40 |
tokenizer=tokenizer,
|
| 41 |
max_new_tokens=256,
|
| 42 |
-
temperature=0.2,
|
| 43 |
)
|
| 44 |
llm = HuggingFacePipeline(pipeline=pipe)
|
| 45 |
|
| 46 |
-
# 4. QA Chain
|
| 47 |
print("πΉ Building RetrievalQA...")
|
| 48 |
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
|
| 49 |
qa = RetrievalQA.from_chain_type(
|
| 50 |
llm=llm,
|
| 51 |
retriever=retriever,
|
| 52 |
-
return_source_documents=False
|
|
|
|
| 53 |
)
|
| 54 |
|
| 55 |
print("β
QA pipeline ready.")
|
| 56 |
return qa
|
| 57 |
|
| 58 |
-
|
| 59 |
-
# Build at import time (so it's ready when app runs)
|
| 60 |
try:
|
| 61 |
qa_pipeline = build_qa()
|
| 62 |
except Exception as e:
|
| 63 |
qa_pipeline = None
|
| 64 |
print("β Failed to build QA pipeline:", e)
|
| 65 |
-
traceback.print_exc()
|
| 66 |
|
| 67 |
|
| 68 |
def get_answer(query: str) -> str:
|
|
|
|
| 1 |
from langchain_community.vectorstores import Chroma
|
| 2 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 3 |
from langchain_community.llms import HuggingFacePipeline
|
| 4 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
| 5 |
from langchain.chains import RetrievalQA
|
| 6 |
+
import traceback
|
|
|
|
| 7 |
|
| 8 |
def build_qa():
|
| 9 |
"""Builds and returns the RAG QA pipeline."""
|
|
|
|
| 22 |
)
|
| 23 |
print("π Docs in DB:", vectorstore._collection.count())
|
| 24 |
|
| 25 |
+
# 3. Load LLM (Flan-T5 small for lightweight QA)
|
| 26 |
print("πΉ Loading LLM...")
|
| 27 |
+
model_id = "google/flan-t5-small"
|
| 28 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 29 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
pipe = pipeline(
|
| 32 |
+
"text2text-generation",
|
| 33 |
model=model,
|
| 34 |
tokenizer=tokenizer,
|
| 35 |
max_new_tokens=256,
|
|
|
|
| 36 |
)
|
| 37 |
llm = HuggingFacePipeline(pipeline=pipe)
|
| 38 |
|
| 39 |
+
# 4. QA Chain with retrieval
|
| 40 |
print("πΉ Building RetrievalQA...")
|
| 41 |
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
|
| 42 |
qa = RetrievalQA.from_chain_type(
|
| 43 |
llm=llm,
|
| 44 |
retriever=retriever,
|
| 45 |
+
return_source_documents=False,
|
| 46 |
+
chain_type="stuff" # simplest chain, passes context + question
|
| 47 |
)
|
| 48 |
|
| 49 |
print("β
QA pipeline ready.")
|
| 50 |
return qa
|
| 51 |
|
| 52 |
+
# Build once
|
|
|
|
| 53 |
try:
|
| 54 |
qa_pipeline = build_qa()
|
| 55 |
except Exception as e:
|
| 56 |
qa_pipeline = None
|
| 57 |
print("β Failed to build QA pipeline:", e)
|
| 58 |
+
traceback.print_exc()
|
| 59 |
|
| 60 |
|
| 61 |
def get_answer(query: str) -> str:
|