import numpy as np
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
import spaces


# For embeddings using transformers models
@spaces.GPU
def get_embeddings(texts, model, tokenizer):
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    # Move inputs to the same device as the model (GPU when available)
    encoded_input = {k: v.to(model.device) for k, v in encoded_input.items()}
    with torch.no_grad():
        model_output = model(**encoded_input)
    # Mean pooling for sentence embeddings
    token_embeddings = model_output.last_hidden_state
    attention_mask = encoded_input['attention_mask']
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return embeddings.cpu().numpy()


# Calculate cosine similarity using numpy
def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


# Load models
def load_models():
    # Embedding model
    embed_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    embed_model = AutoModel.from_pretrained("bert-base-uncased")
    # Generation model
    gen_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    generator = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
    return embed_model, embed_tokenizer, generator, gen_tokenizer


# Process uploaded text files
def process_documents(files):
    documents = []
    for file in files:
        with open(file.name, "r", encoding="utf-8") as f:
            content = f.read()
        # Simple document chunking by paragraphs
        paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
        documents.extend(paragraphs)
    return documents


# Create index from documents
def create_index(model, tokenizer, documents):
    if not documents:
        return None, None
    # Create embeddings
    embeddings = get_embeddings(documents, model, tokenizer)
    return embeddings, documents


# Retrieve relevant documents
def retrieve(query, embeddings, documents, model, tokenizer, top_k=3):
    if embeddings is None or documents is None:
        return []
    # Encode query
    query_embedding = get_embeddings([query], model, tokenizer)[0]
    # Calculate similarities
    similarities = [cosine_similarity(query_embedding, doc_embed) for doc_embed in embeddings]
    # Get top k indices
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    # Return relevant documents
    return [documents[idx] for idx in top_indices]


# Generate answer
@spaces.GPU
def generate_answer(query, context, tokenizer, generator):
    if not context:
        return "No documents have been uploaded yet. Please upload some text files first."
    # Combine context
    combined_context = " ".join(context)
    # Create prompt
    prompt = f"Context: {combined_context}\n\nQuestion: {query}\n\nAnswer:"
    # Generate answer
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    inputs = {k: v.to(generator.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = generator.generate(
            **inputs,
            max_length=256,
            # Deterministic beam search; temperature/top_p were dropped because
            # they only take effect when do_sample=True
            num_beams=4,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# RAG pipeline
def rag_pipeline(query, files):
    try:
        # Only the names assigned below need the global declaration
        global doc_embeddings, indexed_documents
        if not files:
            return "Please upload some text files first."
        # Process documents
        documents = process_documents(files)
        # Create embeddings
        doc_embeddings, indexed_documents = create_index(embed_model, embed_tokenizer, documents)
        # Retrieve relevant context
        context = retrieve(query, doc_embeddings, indexed_documents, embed_model, embed_tokenizer)
        # Generate answer
        answer = generate_answer(query, context, gen_tokenizer, generator)
        return answer
    except Exception as e:
        return f"An error occurred: {str(e)}"


# Initialize global variables
embed_model, embed_tokenizer, generator, gen_tokenizer = load_models()
doc_embeddings, indexed_documents = None, None

# Gradio interface
with gr.Blocks(title="RAG Demo") as demo:
    gr.Markdown("# 📄🔍 Retrieval-Augmented Generation (RAG) Demo")
    gr.Markdown("Upload text files and ask questions about their content.")

    with gr.Row():
        with gr.Column(scale=1):
            file_output = gr.File(
                file_count="multiple",
                label="Upload Text Files (.txt)",
                file_types=[".txt"],
            )
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask a question about the uploaded documents...",
            )
            submit_btn = gr.Button("Get Answer", variant="primary")
            answer_output = gr.Textbox(label="Answer", lines=10)

    submit_btn.click(
        rag_pipeline,
        inputs=[query_input, file_output],
        outputs=answer_output,
    )

    gr.Markdown(
        """
        ## How it works
        1. Upload one or more `.txt` files
        2. Ask a question related to the content
        3. The system will:
           - Create embeddings using BERT
           - Find similar passages using vector similarity
           - Retrieve relevant context for your query
           - Generate an answer using `flan-t5-base`

        Built with 🤗 Hugging Face's models and Gradio
        """
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
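
# A minimal sketch of exercising the pipeline without the Gradio UI, e.g. as a
# quick smoke test from a REPL (assumes the models above loaded successfully;
# the sample documents and question are hypothetical):
#
#     docs = ["Paris is the capital of France.", "The Eiffel Tower is in Paris."]
#     emb, idx = create_index(embed_model, embed_tokenizer, docs)
#     ctx = retrieve("Where is the Eiffel Tower?", emb, idx, embed_model, embed_tokenizer)
#     print(generate_answer("Where is the Eiffel Tower?", ctx, gen_tokenizer, generator))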