# app.py — Local RAG app: OCR-capable document loading, website scraping,
# FAISS vector search, and an Ollama-backed RetrievalQA chain behind a Gradio UI.
import os
import pytesseract
import requests
from bs4 import BeautifulSoup
from PIL import Image
from pdf2image import convert_from_path
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.llms import Ollama
from langchain.chains import RetrievalQA
from langchain.schema import Document
import gradio as gr
# ========== 1. Load Local Documents with OCR Support ==========
def load_local_documents(paths):
    """Load .txt, .pdf, and image files into LangChain ``Document`` objects.

    PDFs are read with ``PyPDFLoader`` when possible; if that fails (e.g. a
    scanned PDF with no extractable text layer, or the loader is missing),
    each page is rasterized via pdf2image and OCR'd with Tesseract. Images
    are OCR'd directly. Files with any other extension are silently skipped.

    Args:
        paths: iterable of file-system path strings.

    Returns:
        list[Document]: one Document per text file / image, and one per
        PDF page (PyPDFLoader) or per OCR'd page (fallback).
    """
    all_docs = []
    for path in paths:
        lower = path.lower()
        if lower.endswith(".txt"):
            with open(path, 'r', encoding='utf-8') as f:
                text = f.read()
            all_docs.append(Document(page_content=text, metadata={"source": path}))
        elif lower.endswith(".pdf"):
            try:
                from langchain.document_loaders import PyPDFLoader
                all_docs.extend(PyPDFLoader(path).load())
            except Exception:
                # FIX: was a bare ``except:``, which also swallows
                # KeyboardInterrupt/SystemExit. Fall back to OCR: rasterize
                # each page and run Tesseract on the image.
                pages = convert_from_path(path)
                for i, page in enumerate(pages):
                    text = pytesseract.image_to_string(page)
                    all_docs.append(Document(page_content=text, metadata={"page": i, "source": path}))
        elif lower.endswith((".png", ".jpg", ".jpeg")):
            img = Image.open(path)
            text = pytesseract.image_to_string(img)
            all_docs.append(Document(page_content=text, metadata={"source": path}))
    return all_docs
# ========== 2. Crawl a Website and Extract Text ==========
def scrape_website(url):
    """Fetch *url* and return its visible text as a single Document.

    Best-effort: any network or HTTP error is logged and an empty list is
    returned, so callers can safely ``extend()`` with the result.

    Args:
        url: absolute URL to fetch (10-second timeout).

    Returns:
        list[Document]: one Document with the page text, or ``[]`` on failure.
    """
    try:
        response = requests.get(url, timeout=10)
        # FIX: treat 4xx/5xx responses as failures instead of silently
        # indexing the body of an error page as document content.
        response.raise_for_status()
        soup = BeautifulSoup(response.text, "html.parser")
        text = soup.get_text(separator="\n")
        return [Document(page_content=text, metadata={"source": url})]
    except Exception as e:
        print(f"Failed to scrape {url}: {e}")
        return []
# ========== 3. Chunk, Embed, and Store in FAISS ==========
def build_vector_db(documents, db_path="faiss_index"):
    """Chunk *documents*, embed the chunks, and persist a FAISS index.

    Uses 500-character chunks with 50-character overlap and the
    all-MiniLM-L6-v2 sentence-transformer for embeddings. The index is
    saved to *db_path* and also returned for immediate use.
    """
    splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    embedder = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    index = FAISS.from_documents(splitter.split_documents(documents), embedder)
    index.save_local(db_path)
    return index
# ========== 4. Set Up RAG Chain with Local LLM (Ollama) ==========
def get_rag_chain(db_path="faiss_index"):
    """Load the persisted FAISS index and wrap it in a RetrievalQA chain.

    Args:
        db_path: directory previously written by :func:`build_vector_db`.

    Returns:
        RetrievalQA: chain answering queries with top-3 similarity retrieval
        through a local Ollama ``mistral`` model.
    """
    embed_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = FAISS.load_local(db_path, embed_model)
    # FIX: as_retriever() does not take a bare ``k`` keyword — the top-k
    # must be passed via search_kwargs, otherwise the retriever falls back
    # to its default instead of returning 3 chunks.
    retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    llm = Ollama(model="mistral")  # requires the `mistral` model to be running in Ollama
    return RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
# ========== 5. Gradio UI ==========
# Module-level RetrievalQA chain; replaced by load_all_docs() once an index exists.
qa_chain = None
def run_query(question):
    """Answer *question* via the loaded RAG chain, or prompt the user to load docs."""
    if not qa_chain:
        return "Please load documents first."
    return qa_chain({"query": question})["result"]
def load_all_docs(local_files, website_urls):
    """Gradio handler: ingest uploaded files and URLs, then (re)build the index.

    Args:
        local_files: file paths from the gr.File component (may be None).
            NOTE(review): assumes gradio hands back path strings — confirm
            against the installed gradio version's File return type.
        website_urls: comma-separated URL string from the textbox (may be empty).

    Returns:
        str: status message shown in the UI.
    """
    global qa_chain  # declared up front: this handler replaces the module-level chain
    docs = []
    if local_files:
        docs.extend(load_local_documents(local_files))
    if website_urls:
        for url in website_urls.split(","):
            docs.extend(scrape_website(url.strip()))
    # FIX: guard the empty corpus — FAISS.from_documents crashes when every
    # file fails to load and every URL fails to scrape.
    if not docs:
        return "No documents could be loaded. Check your files/URLs and try again."
    build_vector_db(docs)
    qa_chain = get_rag_chain()
    return f"Indexed {len(docs)} documents. Ready to answer queries!"
# ----- Gradio UI wiring -----
# Tab "Ask Questions": free-text QA against the indexed corpus.
demo = gr.Interface(
    fn=run_query,
    inputs=gr.Textbox(placeholder="Ask your question..."),
    outputs="text",
    # FIX: title was mojibake ("๐Ÿ“š ...") — UTF-8 emoji bytes mis-decoded;
    # restored the intended emoji.
    title="📚 Local RAG App",
    description="Load local files & websites, then ask questions below.",
)
# Tab "Load Docs": upload files / paste URLs to (re)build the FAISS index.
load_interface = gr.Interface(
    fn=load_all_docs,
    inputs=[
        gr.File(file_types=[".txt", ".pdf", ".jpg", ".png"], file_count="multiple", label="Upload Files"),
        gr.Textbox(placeholder="https://example.com, https://another.com", label="Website URLs (comma-separated)"),
    ],
    outputs="text",
    title="🗂️ Load Your Documents",  # FIX: was mojibake ("๐Ÿ—‚๏ธ ...")
)
# Two-tab app; loading must happen before querying (run_query checks qa_chain).
app = gr.TabbedInterface([load_interface, demo], ["Load Docs", "Ask Questions"])
app.launch()