# Hugging Face Spaces app (the hosted Space last reported status: Runtime error)
| import gradio as gr | |
| import ollama | |
| import bs4 | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import WebBaseLoader | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.embeddings import OllamaEmbeddings | |
# Check if user has inputted a URL or uploaded a document and load, split, and retrieve documents
def load_and_retrieve(url, document):
    """Build a vector-store retriever from either a web URL or an uploaded PDF.

    Args:
        url: Web page URL to scrape, or a falsy value to skip this source.
        document: Local path of an uploaded PDF file, or a falsy value to skip.

    Returns:
        A Chroma retriever over the split document chunks, or ``None`` when
        neither a URL nor a document was supplied — callers must handle None.
    """
    # URL takes precedence when both inputs are given (matches original order).
    if url:
        # bs_kwargs is empty, so the whole page is parsed without filtering.
        loader = WebBaseLoader(
            web_paths=(url,),
            bs_kwargs=dict()
        )
        docs = loader.load()
        return _build_retriever(docs)
    if document:
        # Gradio's file component supplies the upload as a local file path.
        loader = PyPDFLoader(document)
        docs = loader.load_and_split()
        return _build_retriever(docs)
    # Neither input provided.
    return None

def _build_retriever(docs):
    """Split docs into overlapping chunks, embed them, and return a retriever."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)
    embeddings = OllamaEmbeddings(model="nomic-embed-text")
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
    return vectorstore.as_retriever()
# Function to format documents
def format_docs(docs):
    """Concatenate each document's page content, separated by blank lines."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
# Function that defines the RAG chain
def rag_chain(url=False, document=False, question=''):
    """Answer *question* with retrieval-augmented generation over the given source.

    Args:
        url: Web page URL to use as context, or falsy to skip.
        document: Uploaded PDF path to use as context, or falsy to skip.
        question: The user's question.

    Returns:
        The model's answer string, or a usage hint when no source was given.
    """
    retriever = load_and_retrieve(url, document)
    # Guard: load_and_retrieve returns None when neither input is supplied;
    # without this check the call below raised AttributeError at runtime.
    if retriever is None:
        return "Please enter a URL or upload a PDF document first."
    retrieved_docs = retriever.invoke(question)
    formatted_context = format_docs(retrieved_docs)
    formatted_prompt = f"Question: {question}\n\nContext: {formatted_context}"
    # Debug trace of the exact prompt sent to the model.
    print("==============")
    print(formatted_prompt)
    print("==============")
    response = ollama.chat(model='llama3', messages=[{'role': 'user', 'content': formatted_prompt}])
    return response['message']['content']
# Gradio interface: a URL textbox, a file upload, and a question textbox,
# wired in that order to rag_chain(url, document, question).
_ui_config = dict(
    fn=rag_chain,
    inputs=["text", "file", "text"],
    outputs="text",
    title="RAG Chain Question Answering",
    description="Enter a URL or upload a document and a query to get answers from the RAG chain.",
)
iface = gr.Interface(**_ui_config)

# Launch the app; share=True also publishes a temporary public link.
iface.launch(share=True)