# rag_deployment / app.py
# Hugging Face Space file (author: ilangodj — "Update app.py", commit 9052658, verified).
# NOTE: this header was web-page residue pasted as plain text; it is commented
# out so the module parses.
import faiss
import numpy as np
import os
import streamlit as st
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Embedding model --------------------------------------------------------
# Sentence-Transformer that embeds both the knowledge-base documents and the
# incoming user queries into the same vector space.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# --- Knowledge base ---------------------------------------------------------
# Path to the text file whose lines form the knowledge base; each non-empty
# line is treated as one retrievable document.
content_file = "Ilango Profile.txt"

with open(content_file, "r", encoding="utf-8") as file:
    # Strip trailing newlines and drop blank lines so we never embed empty
    # strings (they would pollute the similarity index with junk vectors).
    documents = [line.strip() for line in file if line.strip()]

# --- Vector index -----------------------------------------------------------
# Embed every document once at startup and store the vectors in a flat
# (exact, L2-distance) FAISS index for nearest-neighbour retrieval.
document_embeddings = embedding_model.encode(documents)
dimension = document_embeddings.shape[1]  # embedding width of the model output
index = faiss.IndexFlatL2(dimension)
index.add(np.array(document_embeddings))

# --- Generator model --------------------------------------------------------
# GPT-2 generates the final answer conditioned on the retrieved context.
# NOTE(review): gpt2 has a 1024-token context window, so prompts must be
# truncated before generation.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Function to retrieve top-k relevant documents
def retrieve_relevant_docs(query, top_k=3):
    """Return the ``top_k`` documents most similar to ``query`` as one string.

    Args:
        query: Free-text user question.
        top_k: Maximum number of documents to retrieve; clamped to the
            knowledge-base size.

    Returns:
        The retrieved document lines joined with blank lines, ready to be
        used as the generation context (empty string if nothing matches).
    """
    # Never ask FAISS for more neighbours than documents exist: it pads
    # missing results with index -1, which would index documents[] wrongly.
    k = min(top_k, len(documents))
    query_embedding = embedding_model.encode([query])  # shape (1, dim)
    _, idx = index.search(np.array(query_embedding), k=k)
    # Guard against -1 padding anyway (e.g. an empty index).
    return "\n\n".join(documents[i] for i in idx[0] if i >= 0)
# Function to generate response using GPT-2
def generate_response(context, query):
    """Generate an answer with GPT-2 conditioned on the retrieved context.

    Args:
        context: Retrieved document text used to ground the answer.
        query: The user's question.

    Returns:
        Only the newly generated answer text (the echoed prompt is stripped).
    """
    prompt = f"Context: {context}\n\nUser Query: {query}\n\nAnswer:"
    # GPT-2's context window is 1024 tokens; truncate the prompt so there is
    # always room left for the generated answer.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=824)
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=200,  # bound the *answer* length, not prompt+answer
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 defines no pad token
    )
    # Decode only the tokens produced after the prompt, so callers receive
    # just the answer rather than the full echoed context/question.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
# --- Streamlit UI -----------------------------------------------------------
# Simple single-page flow: take a question, show the retrieved evidence, then
# show the generated answer.
st.title("Retrieval-Augmented Generation (RAG) System")
st.write("Ask a question, and the system will retrieve relevant documents and generate an answer.")

user_query = st.text_input("Enter your question:")
if user_query:
    # Step 1: retrieve the documents most relevant to the question.
    retrieved_docs = retrieve_relevant_docs(user_query)
    st.write("#### Retrieved Documents:")
    st.write(retrieved_docs)

    # Step 2: generate an answer grounded in the retrieved documents.
    response = generate_response(retrieved_docs, user_query)
    st.write("#### Answer:")
    st.write(response)