# ramayan_rag / app.py
# Author: anshumanpatil — commit e8586c9 ("RAG references changed")
import streamlit as st
import pandas as pd
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from sentence_transformers import SentenceTransformer
from langchain_community.document_loaders import DirectoryLoader, TextLoader
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
docs = []
db = None
extracted_text = None
# ------------------------------
# Load Model for pretraining
# ------------------------------
@st.cache_resource
def load_model():
# model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
return pipeline("text-generation", model=model, tokenizer=tokenizer)
@st.cache_resource
def extract_text():
uploaded_data_path = "./ramayan"
loader = DirectoryLoader(
path=uploaded_data_path,
glob="**/*.txt",
loader_cls=TextLoader,
recursive=True
)
documents = loader.load()
return "\n".join([doc.page_content for doc in documents])
def build_faiss(_docs):
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
return FAISS.from_documents(_docs, embeddings)
with st.spinner("πŸ”„ Loading Model..."):
generator = load_model()
with st.spinner("πŸ”„ Loading Knowldge Base..."):
extracted_text = extract_text()
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = [Document(page_content=chunk) for chunk in splitter.split_text(extracted_text)]
db = build_faiss(docs)
st.success("βœ… Knowledge Base ready! From :- Ramayan")
# ------------------------------
# Title
# ------------------------------
st.title("πŸ“š RAG For Ramayan")
st.markdown("This app uses a local LLM model to answer questions about Ramayan using RAG (Retrieval Augmented Generation).")
with st.form(key='my_form'):
query = st.text_input("πŸ’¬ Ask a question about Ramayan(Required)", placeholder="Rishyasring's Departure")
max_new_tokens_model = st.slider("Max New Tokens (Optional):", min_value=50, max_value=500, value=150, step=25)
temperature_model = st.slider("Temperature (Optional):", min_value=0.0, max_value=0.9, value=0.5, step=0.1)
submit_button = st.form_submit_button("Submit")
if submit_button:
if query and db and extracted_text and len(docs) > 0:
retriever = db.as_retriever(search_kwargs={"k": 3})
retrieved_docs = retriever.get_relevant_documents(query)
context = "\n".join([doc.page_content for doc in retrieved_docs])
with st.spinner("πŸ€” Generating answer..."):
result = generator(
f"Context:\n{context}\n\nQuestion: {query}\nAnswer:",
max_new_tokens=max_new_tokens_model,
temperature=temperature_model,
top_p=0.9
)
generated = result[0]["generated_text"]
answer_only = generated.split("Answer:")[-1].strip()
st.write("πŸ“ Answer:", answer_only)