|
|
import gradio as gr |
|
|
from bs4 import BeautifulSoup as bs |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain_community.document_loaders import WebBaseLoader |
|
|
from langchain_community.vectorstores import Chroma |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import pipeline |
|
|
__import__('pysqlite3') |
|
|
import sys |
|
|
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') |
|
|
|
|
|
|
|
|
class CustomEmbeddings:
    """Adapter that exposes a SentenceTransformer through the two-method
    embeddings interface (``embed_documents`` / ``embed_query``) that the
    vector store calls.
    """

    def __init__(self, model):
        # ``model`` only needs an ``encode(texts, convert_to_tensor=True)``
        # method whose result supports ``.tolist()``.
        self.model = model

    def embed_documents(self, texts):
        """Return one embedding (a list of floats) per entry of ``texts``."""
        # ``.tolist()`` turns the encoded batch into plain nested Python lists.
        return self.model.encode(texts, convert_to_tensor=True).tolist()

    def embed_query(self, text):
        """Return the embedding of a single query string as a flat list.

        The retriever embeds the question through this method; without it
        the first ``retriever.invoke(...)`` fails with ``AttributeError``.
        """
        # Encode as a one-element batch so the query goes through exactly
        # the same path as the documents, then unwrap the single row.
        return self.embed_documents([text])[0]


def load_and_retrieve_docs(url):
    """Download ``url``, index its text in Chroma and return a retriever.

    The page is loaded with ``WebBaseLoader``, split into chunks of 1000
    characters with 200 characters of overlap, embedded with the
    ``sentence-transformers/all-MiniLM-L6-v2`` model and stored in a new
    Chroma collection.

    Args:
        url: Address of the web page to index.

    Returns:
        The retriever produced by ``vectorstore.as_retriever()``.
    """
    import uuid  # local import: only needed to name the per-call collection

    loader = WebBaseLoader(
        web_paths=(url,),
        bs_kwargs=dict(),
    )
    docs = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    embeddings = CustomEmbeddings(embedding_model)

    # NOTE(review): the in-memory Chroma client can be shared process-wide
    # (depends on the installed chromadb version), in which case the default
    # collection would keep accumulating chunks from every URL queried so
    # far. A unique collection per call keeps retrieval scoped to this page.
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embeddings,
        collection_name=f"rag-{uuid.uuid4().hex}",
    )
    return vectorstore.as_retriever()
|
|
|
|
|
|
|
|
def format_docs(docs):
    """Concatenate the ``page_content`` of every document in ``docs``.

    Consecutive documents are separated by one blank line; an empty input
    yields an empty string.
    """
    separator = "\n\n"
    page_texts = [document.page_content for document in docs]
    return separator.join(page_texts)
|
|
|
|
|
|
|
|
def rag_chain(url, question):
    """Answer ``question`` using text retrieved from the page at ``url``.

    The page is indexed with ``load_and_retrieve_docs``, the chunks returned
    for the question are joined with ``format_docs``, and the resulting
    prompt is continued by a GPT-2 text-generation pipeline.

    Args:
        url: Address of the web page to use as context.
        question: The user's question.

    Returns:
        The pipeline's ``generated_text`` for the first result (the prompt
        followed by the generated continuation).

    Raises:
        ValueError: If ``url`` or ``question`` is empty or only whitespace.
    """
    # Fail fast with a clear message before any download or model loading.
    if not url or not url.strip():
        raise ValueError("A URL is required.")
    if not question or not question.strip():
        raise ValueError("A question is required.")

    retriever = load_and_retrieve_docs(url)
    retrieved_docs = retriever.invoke(question)
    formatted_context = format_docs(retrieved_docs)
    formatted_prompt = f"Question: {question}\n\nContext: {formatted_context}"

    # The Hugging Face Hub id of the model is "gpt2"; "gpt-2" does not resolve.
    qa_pipeline = pipeline("text-generation", model="gpt2")

    max_new_tokens = 200
    tokenizer = qa_pipeline.tokenizer
    # Keep prompt + generated tokens inside the model's context window. The
    # question comes first in the prompt, so cutting from the right only
    # drops trailing context. A few tokens of slack absorb any difference
    # when the decoded text is tokenized again by the pipeline.
    prompt_budget = tokenizer.model_max_length - max_new_tokens - 8
    prompt_ids = tokenizer(
        formatted_prompt, truncation=True, max_length=prompt_budget
    )["input_ids"]
    formatted_prompt = tokenizer.decode(prompt_ids)

    # ``max_length`` would count the prompt tokens as well, leaving no room
    # to generate once the prompt is longer than 200 tokens; limit only the
    # newly generated tokens instead.
    response = qa_pipeline(formatted_prompt, max_new_tokens=max_new_tokens)
    return response[0]['generated_text']
|
|
|
|
|
|
|
|
# Gradio front end: two free-text inputs feed rag_chain and its return value
# is rendered as text.
iface = gr.Interface(
    rag_chain,
    inputs=["text", "text"],
    outputs="text",
    title="RAG Chain Question Answering",
    description="Enter a URL and a query to get answers from the RAG chain.",
)

# Start serving the UI as soon as this module is executed.
iface.launch()
|
|
|