# rag-embedder / app / langchain_rag.py
# jackenmail's picture
# Upload 3 files
# 41ac698 verified
# ─────────────────────────────────────────────────────────────
# app/langchain_rag.py
# LangChain version of the RAG pipeline
# ─────────────────────────────────────────────────────────────
import os
import sys
import numpy as np
from dotenv import load_dotenv
load_dotenv()
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gradio_client import Client
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceHub
from langchain.chains import RetrievalQA
from langchain.schema import Document
# ── Wrap your HF Gradio Space as LangChain Embeddings ────────
class GradioEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter backed by a Hugging Face Gradio Space.

    Every text is sent to the Space's ``/predict`` endpoint through a
    ``gradio_client.Client`` and the endpoint's response is handed back
    unchanged.
    """

    def __init__(self, space: str = None):
        """Open a client connection to the Space.

        Args:
            space: Identifier of the Gradio Space. When falsy, the
                ``GRADIO_SPACE`` environment variable is used, falling back
                to ``"your-username/rag-embedder-app"``.
        """
        if not space:
            space = os.getenv("GRADIO_SPACE", "your-username/rag-embedder-app")
        self.space = space
        self.client = Client(self.space)
        print(f"Connected to Gradio Space: {self.space}")

    def embed_documents(self, texts: list) -> list:
        """Embed each text in ``texts`` with one ``/predict`` call per text.

        Returns:
            A list holding the endpoint's response for every input, in input
            order.
        """
        responses = []
        for item in texts:
            responses.append(self.client.predict(item, api_name="/predict"))
        return responses

    def embed_query(self, text: str) -> list:
        """Embed a single query string with one ``/predict`` call.

        Returns:
            Whatever the endpoint returns for ``text``.
            NOTE(review): presumably a flat list of floats — confirm against
            the Space's output format.
        """
        response = self.client.predict(text, api_name="/predict")
        return response
# ── Load documents ────────────────────────────────────────────
def load_documents(path: str) -> list:
    """Read a text file and wrap every non-blank line in a ``Document``.

    Args:
        path: Path of the text file to read; each line is one document.

    Returns:
        A list of ``Document`` objects, one per non-blank line, whose
        ``page_content`` is the line with surrounding whitespace stripped.
        Blank and whitespace-only lines are skipped, so a file without any
        content yields ``[]``.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8: without an encoding argument open() falls
    # back to the platform locale, which breaks non-ASCII documents on
    # systems whose default is not UTF-8.
    with open(path, encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        contents = [text for text in stripped_lines if text]
    return [Document(page_content=text) for text in contents]
# ── Build LangChain RAG chain ─────────────────────────────────
def build_rag_chain():
    """Assemble the retrieval-augmented question-answering chain.

    The documents are loaded, embedded through ``GradioEmbeddings``, indexed
    in a FAISS vector store, and combined with a ``HuggingFaceHub`` LLM in a
    ``RetrievalQA`` chain of type ``"stuff"``. Progress messages are printed
    along the way.

    Environment variables read:
        DOCS_PATH: Document file (default ``"data/sample_docs.txt"``).
        HF_TOKEN: Hugging Face API token (default ``""``).
        LLM_MODEL: Hub repo id of the LLM
            (default ``"mistralai/Mistral-7B-Instruct-v0.1"``).

    Returns:
        The ``RetrievalQA`` chain, configured to return its source documents.
    """
    source_path = os.getenv("DOCS_PATH", "data/sample_docs.txt")
    api_token = os.getenv("HF_TOKEN", "")
    model_repo = os.getenv("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.1")

    print("Setting up LangChain RAG pipeline...")

    corpus = load_documents(source_path)
    print(f"Loaded {len(corpus)} documents")

    # Index the corpus with embeddings served by the Gradio Space, then
    # expose the index as a retriever asking for 3 results per query.
    index = FAISS.from_documents(corpus, GradioEmbeddings())
    top_k_retriever = index.as_retriever(search_kwargs={"k": 3})

    generation_settings = {"max_new_tokens": 200, "temperature": 0.3}
    hub_llm = HuggingFaceHub(
        repo_id=model_repo,
        huggingfacehub_api_token=api_token,
        model_kwargs=generation_settings,
    )

    qa_chain = RetrievalQA.from_chain_type(
        llm=hub_llm,
        retriever=top_k_retriever,
        chain_type="stuff",
        return_source_documents=True,
    )
    print("LangChain RAG chain ready!")
    return qa_chain
# ── Run ───────────────────────────────────────────────────────
if __name__ == "__main__":
    # Demo run: build the chain once, then answer a fixed set of questions,
    # printing each answer together with at most two source passages.
    rag_chain = build_rag_chain()
    demo_questions = (
        "What is the refund policy?",
        "How do I reset my password?",
        "When can I contact support?",
    )
    divider_width = 55

    print("\n" + "=" * divider_width)
    for question in demo_questions:
        response = rag_chain({"query": question})
        answer_text = response["result"]
        source_texts = [
            source_doc.page_content
            for source_doc in response["source_documents"]
        ]
        print(f"Q: {question}")
        print(f"A: {answer_text}")
        print(f"Sources: {source_texts[:2]}")
        print("-" * divider_width)