# vitamin-bot / app.py (by farnoorcoder)
# Last commit: 2a863b1 "updated app.py" (verified)
import functools
import logging
import os
import shutil
import tempfile
import uuid
import zipfile
from typing import Optional

import gradio as gr
from langchain.chains import RetrievalQA
from langchain.schema.runnable import Runnable
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# Folder the bundled PDFs are extracted into (relative to the working directory).
DATA_DIR = "Week14_content"
# Archive shipped with the app that contains the PDFs to index.
ZIP_FILE = "Week_14__MLS14 - Adv RAG.zip"


def unzip_if_needed(
    data_dir: Optional[str] = None, zip_file: Optional[str] = None
) -> None:
    """Extract the bundled archive unless its destination folder already exists.

    Args:
        data_dir: Destination folder; ``DATA_DIR`` when omitted.
        zip_file: Archive to extract; ``ZIP_FILE`` when omitted.

    Raises:
        FileNotFoundError: if the folder is missing and the archive is too.
        zipfile.BadZipFile: if the archive is not a valid zip file.

    Extraction happens in a temporary sibling folder that is renamed into
    place only once every member has been written. A crash or error mid-way
    therefore never leaves a half-populated ``data_dir`` behind, which later
    calls would otherwise treat as complete and never repair.
    """
    data_dir = DATA_DIR if data_dir is None else data_dir
    zip_file = ZIP_FILE if zip_file is None else zip_file
    if os.path.exists(data_dir):
        return

    parent = os.path.dirname(os.path.abspath(data_dir))
    # extractall() used to create missing parent folders; keep that behavior.
    os.makedirs(parent, exist_ok=True)
    # Stage on the same filesystem so the final rename is a cheap move.
    staging_dir = tempfile.mkdtemp(prefix=".unzip-", dir=parent)
    try:
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            zip_ref.extractall(staging_dir)
        try:
            os.rename(staging_dir, data_dir)
        except OSError:
            # A concurrent request may have finished first; its copy is complete.
            if not os.path.isdir(data_dir):
                raise
    finally:
        # No-op after a successful rename; otherwise discards the partial copy.
        shutil.rmtree(staging_dir, ignore_errors=True)
def load_documents(data_dir: Optional[str] = None) -> list:
    """Load every PDF under *data_dir* (recursively) into LangChain documents.

    Args:
        data_dir: Folder to scan; ``DATA_DIR`` when omitted.

    Returns:
        The concatenated output of ``PyPDFLoader(path).load()`` for each PDF,
        in a deterministic (sorted) folder/file order.

    Files are matched on a case-insensitive ``.pdf`` extension. macOS archive
    metadata (``__MACOSX`` folders and ``._*`` AppleDouble files) is skipped:
    those entries carry a ``.pdf`` name but are not PDFs and would make the
    loader fail.
    """
    root_dir = DATA_DIR if data_dir is None else data_dir
    documents = []
    for folder, subdirs, filenames in os.walk(root_dir):
        # Prune in place so os.walk neither descends into metadata folders nor
        # visits folders in a filesystem-dependent order.
        subdirs[:] = sorted(name for name in subdirs if name != "__MACOSX")
        for filename in sorted(filenames):
            if filename.startswith("._") or not filename.lower().endswith(".pdf"):
                continue
            loader = PyPDFLoader(os.path.join(folder, filename))
            documents.extend(loader.load())
    return documents
@functools.lru_cache(maxsize=4)
def _load_split_documents(chunk_size: int, chunk_overlap: int) -> tuple:
    """Unzip (if needed), load and chunk the corpus; cached as it needs no API key."""
    unzip_if_needed()
    docs = load_documents()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    # A tuple so the cached value cannot be mutated by callers.
    return tuple(splitter.split_documents(docs))


def build_rag_chain(
    api_key: str,
    *,
    model: str = "gpt-4-turbo",
    chunk_size: int = 1000,
    chunk_overlap: int = 100,
) -> Runnable:
    """Build a RetrievalQA chain over the bundled PDF corpus.

    Args:
        api_key: OpenAI API key used for both embeddings and the chat model.
        model: Chat model name passed to ``ChatOpenAI``.
        chunk_size: Maximum characters per chunk for the text splitter.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        A ``RetrievalQA`` chain (input key ``"query"``, output key ``"result"``).

    Raises:
        ValueError: if no PDF content was found to index.
    """
    split_docs = _load_split_documents(chunk_size, chunk_overlap)
    if not split_docs:
        raise ValueError(f"No PDF content found under {DATA_DIR!r}; nothing to index.")
    embedding = OpenAIEmbeddings(openai_api_key=api_key)
    # Without an explicit name, every build writes into Chroma's default
    # "langchain" collection of the shared in-process client, so each rebuild
    # appended another copy of every chunk and retrieval returned duplicates.
    vectorstore = Chroma.from_documents(
        list(split_docs),
        embedding,
        collection_name=f"rag-{uuid.uuid4().hex}",
    )
    retriever = vectorstore.as_retriever()
    llm = ChatOpenAI(openai_api_key=api_key, model=model, temperature=0)
    return RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
@functools.lru_cache(maxsize=8)
def _cached_chain(api_key: str):
    """Return the QA chain for *api_key*, building it only on first use.

    Building the chain loads, chunks and embeds the whole corpus, so doing it
    once per key instead of once per question avoids repeated latency and
    embedding charges. Builds that raise are not cached, so a bad key can be
    corrected and retried. Note the cache keeps up to 8 keys (and their
    chains) in process memory.
    """
    return build_rag_chain(api_key)


def query_rag(api_key: str, user_question: str) -> str:
    """Answer *user_question* from the indexed documents using *api_key*.

    This backs a UI callback, so it never raises. Missing input yields a
    prompt message; any failure is logged and returned as a readable
    ``"❌ Error: ..."`` string instead of an answer.
    """
    # Strip so pasted keys with stray whitespace work and blank input is rejected.
    api_key = (api_key or "").strip()
    user_question = (user_question or "").strip()
    if not api_key or not user_question:
        return "Please provide both your OpenAI API key and a question."
    try:
        chain = _cached_chain(api_key)
        # Chain.run is deprecated; invoke() takes/returns RetrievalQA's dict keys.
        return chain.invoke({"query": user_question})["result"]
    except Exception as e:
        # UI boundary: surface the message to the user, keep the traceback in logs.
        logging.getLogger(__name__).exception("RAG query failed")
        return f"❌ Error: {str(e)}"
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown(
        "# 🔍 RAG QA App\n"
        "Ask questions about the documents bundled with this app "
        "(extracted from the zip archive on first use): "
        "enter your OpenAI key and a question."
    )
    api_key_input = gr.Textbox(label="🔑 OpenAI API Key", type="password")
    question_input = gr.Textbox(label="❓ Your Question")
    output_box = gr.Textbox(label="📄 Answer", lines=10)
    ask_button = gr.Button("Ask")
    # Clicking "Ask" or pressing Enter in the question box both run the query.
    ask_button.click(fn=query_rag, inputs=[api_key_input, question_input], outputs=output_box)
    question_input.submit(fn=query_rag, inputs=[api_key_input, question_input], outputs=output_box)
    # NOTE(review): no demo.launch() is visible in this section — confirm it is
    # called further down or that the host launches `demo` itself.