"""app.py: Retrieval-Augmented Generation (RAG) demo Space (ZeroGPU-compatible)."""
import numpy as np
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
import spaces
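
# `@spaces.GPU` targets Hugging Face ZeroGPU Spaces: a GPU is attached only
# while a decorated function runs, so models are loaded on CPU at startup.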

# For embeddings using transformers models
@spaces.GPU
def get_embeddings(texts, model, tokenizer):
    # Use the GPU when one is attached (ZeroGPU), otherwise fall back to CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    encoded_input = tokenizer(
        texts, padding=True, truncation=True, max_length=512, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        model_output = model(**encoded_input)
    # Mean pooling for sentence embeddings: average token vectors, ignoring padding
    token_embeddings = model_output.last_hidden_state
    attention_mask = encoded_input["attention_mask"]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
    return embeddings.cpu().numpy()

# Calculate cosine similarity using numpy
def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
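
# A vectorized variant (a sketch, not used by the app): scoring all documents
# in one matrix operation avoids the per-document Python loop in `retrieve`.
def cosine_similarity_batch(matrix, vector):
    return matrix @ vector / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(vector))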

# Load models
def load_models():
    # Embedding model
    embed_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    embed_model = AutoModel.from_pretrained("bert-base-uncased")
    # Generation model
    gen_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    generator = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
    return embed_model, embed_tokenizer, generator, gen_tokenizer

# Process uploaded text files
def process_documents(files):
    documents = []
    for file in files:
        with open(file.name, "r", encoding="utf-8") as f:
            content = f.read()
        # Simple document chunking by paragraphs
        paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
        documents.extend(paragraphs)
    return documents
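
# Example: a file whose text is "Intro...\n\nDetails..." yields two chunks,
# ["Intro...", "Details..."], since splitting happens on blank lines.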

# Create index from documents
def create_index(model, tokenizer, documents):
    if not documents:
        return None, None
    # Create embeddings
    embeddings = get_embeddings(documents, model, tokenizer)
    return embeddings, documents

# Retrieve relevant documents
def retrieve(query, embeddings, documents, model, tokenizer, top_k=3):
    if embeddings is None or documents is None:
        return []
    # Encode query
    query_embedding = get_embeddings([query], model, tokenizer)[0]
    # Calculate similarities
    similarities = [cosine_similarity(query_embedding, doc_embed) for doc_embed in embeddings]
    # Get top k indices (highest similarity first)
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    # Return relevant documents
    return [documents[idx] for idx in top_indices]
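
# Example: retrieve("What is X?", doc_embeddings, docs, embed_model, embed_tokenizer)
# returns up to the 3 highest-scoring paragraphs, best match first.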

# Generate answer
@spaces.GPU
def generate_answer(query, context, tokenizer, generator):
    if not context:
        return "No documents have been uploaded yet. Please upload some text files first."
    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = generator.to(device)
    # Combine context
    combined_context = " ".join(context)
    # Create prompt
    prompt = f"Context: {combined_context}\n\nQuestion: {query}\n\nAnswer:"
    # Generate answer
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(device)
    with torch.no_grad():
        outputs = generator.generate(
            **inputs,
            max_new_tokens=256,
            num_beams=4,
            do_sample=True,  # required for temperature/top_p to take effect
            temperature=0.7,
            top_p=0.9,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
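
# Note: flan-t5 uses relative position embeddings, so the 1024-token prompt cap
# above is usable even though the tokenizer reports a 512-token model max.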

# RAG pipeline
def rag_pipeline(query, files):
    try:
        global embed_model, embed_tokenizer, generator, gen_tokenizer, doc_embeddings, indexed_documents
        if not files:
            return "Please upload some text files first."
        # Process documents
        documents = process_documents(files)
        # Create embeddings
        doc_embeddings, indexed_documents = create_index(embed_model, embed_tokenizer, documents)
        # Retrieve relevant context
        context = retrieve(query, doc_embeddings, indexed_documents, embed_model, embed_tokenizer)
        # Generate answer
        answer = generate_answer(query, context, gen_tokenizer, generator)
        return answer
    except Exception as e:
        return f"An error occurred: {str(e)}"

# Initialize global variables
embed_model, embed_tokenizer, generator, gen_tokenizer = load_models()
doc_embeddings, indexed_documents = None, None

# Gradio interface
with gr.Blocks(title="RAG Demo") as demo:
    gr.Markdown("# 📄🔍 Retrieval-Augmented Generation (RAG) Demo")
    gr.Markdown("Upload text files and ask questions about their content.")

    with gr.Row():
        with gr.Column(scale=1):
            file_output = gr.File(
                file_count="multiple",
                label="Upload Text Files (.txt)",
                file_types=[".txt"],
            )
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask a question about the uploaded documents...",
            )
            submit_btn = gr.Button("Get Answer", variant="primary")

    answer_output = gr.Textbox(label="Answer", lines=10)

    submit_btn.click(
        rag_pipeline,
        inputs=[query_input, file_output],
        outputs=answer_output,
    )
    gr.Markdown(
        """
## How it works
1. Upload one or more `.txt` files
2. Ask a question related to the content
3. The system will:
   - Create embeddings using BERT
   - Find similar passages using vector similarity
   - Retrieve relevant context for your query
   - Generate an answer using `flan-t5-base`

Built with 🤗 Hugging Face's models and Gradio
        """
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()