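"""RAG demo: upload .txt files and ask questions about their content.

Pipeline: split documents into paragraph chunks, embed them with
bert-base-uncased (mean pooling), retrieve the most similar chunks by
cosine similarity, and answer with google/flan-t5-base over that context.
"""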
import numpy as np
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
import spaces  # Hugging Face Spaces runtime helpers (ZeroGPU); unused directly, kept for Spaces deployment

# For embeddings using transformers models
def get_embeddings(texts, model, tokenizer):
    encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        model_output = model(**encoded_input)
    # Mean pooling for sentence embeddings
    token_embeddings = model_output.last_hidden_state
    attention_mask = encoded_input['attention_mask']
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return embeddings.cpu().numpy()
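
# Shape walk-through: for a batch of two texts padded to 8 tokens with
# bert-base-uncased (hidden size 768), token_embeddings is (2, 8, 768), the
# expanded mask is (2, 8, 768), and the returned array is (2, 768): one
# vector per input text, with padding positions contributing nothing.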

# Calculate cosine similarity using numpy
def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
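
# Toy check: cosine_similarity(np.array([1.0, 0.0]), np.array([1.0, 1.0]))
# is 1 / sqrt(2) ≈ 0.7071, the cosine of the 45-degree angle between them.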

# Load models
def load_models():
    # Embedding model
    embed_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    embed_model = AutoModel.from_pretrained("bert-base-uncased")
    # Generation model
    gen_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    generator = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
    return embed_model, embed_tokenizer, generator, gen_tokenizer

# Process uploaded text files
def process_documents(files):
    documents = []
    for file in files:
        # Gradio passes tempfile-like objects (with .name) or plain path
        # strings, depending on version; handle both.
        path = file.name if hasattr(file, "name") else file
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
        # Simple document chunking by paragraphs
        paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
        documents.extend(paragraphs)
    return documents
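
# Example: a file containing "First point.\n\nSecond point." becomes two
# chunks, ["First point.", "Second point."]; single newlines do not split.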

# Create index from documents
def create_index(model, tokenizer, documents):
    if not documents:
        return None, None
    # Create embeddings
    embeddings = get_embeddings(documents, model, tokenizer)
    return embeddings, documents

# Retrieve relevant documents
def retrieve(query, embeddings, documents, model, tokenizer, top_k=3):
    if embeddings is None or documents is None:
        return []
    # Encode query
    query_embedding = get_embeddings([query], model, tokenizer)[0]
    # Calculate similarities
    similarities = [cosine_similarity(query_embedding, doc_embed) for doc_embed in embeddings]
    # Get top k indices
    top_indices = np.argsort(similarities)[-top_k:][::-1]
    # Return relevant documents
    return [documents[idx] for idx in top_indices]
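
# Worked example with top_k=3: for similarities [0.2, 0.9, 0.5, 0.7],
# np.argsort gives [0, 2, 3, 1]; taking the last three and reversing yields
# [1, 3, 2], i.e. the chunks scoring 0.9, 0.7, 0.5 in descending order.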

# Generate answer
def generate_answer(query, context, tokenizer, generator):
    if not context:
        return "No documents have been uploaded yet. Please upload some text files first."
    # Combine context
    combined_context = " ".join(context)
    # Create prompt
    prompt = f"Context: {combined_context}\n\nQuestion: {query}\n\nAnswer:"
    # Generate answer
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    with torch.no_grad():
        outputs = generator.generate(
            **inputs,
            max_length=256,
            num_beams=4,
            # temperature/top_p dropped: they have no effect (and trigger a
            # warning) unless do_sample=True, which beam search here doesn't use.
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
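
# Note: for encoder-decoder models like flan-t5, max_length in generate()
# bounds the generated answer (256 tokens here), independent of prompt length.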

# RAG pipeline
def rag_pipeline(query, files):
    try:
        global embed_model, embed_tokenizer, generator, gen_tokenizer, doc_embeddings, indexed_documents
        if not files:
            return "Please upload some text files first."
        # Process documents
        documents = process_documents(files)
        # Create embeddings
        doc_embeddings, indexed_documents = create_index(embed_model, embed_tokenizer, documents)
        # Retrieve relevant context
        context = retrieve(query, doc_embeddings, indexed_documents, embed_model, embed_tokenizer)
        # Generate answer
        answer = generate_answer(query, context, gen_tokenizer, generator)
        return answer
    except Exception as e:
        return f"An error occurred: {str(e)}"

# Initialize global variables
embed_model, embed_tokenizer, generator, gen_tokenizer = load_models()
doc_embeddings, indexed_documents = None, None
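# Note: these globals are shared across all Gradio sessions; rag_pipeline
# rebuilds the embedding index on each request from the files it receives.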

# Gradio interface
with gr.Blocks(title="RAG Demo") as demo:
    gr.Markdown("# Retrieval-Augmented Generation (RAG) Demo")
    gr.Markdown("Upload text files and ask questions about their content.")
    with gr.Row():
        with gr.Column(scale=1):
            file_output = gr.File(
                file_count="multiple",
                label="Upload Text Files (.txt)",
                file_types=[".txt"],
            )
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask a question about the uploaded documents...",
            )
            submit_btn = gr.Button("Get Answer", variant="primary")
            answer_output = gr.Textbox(label="Answer", lines=10)
    submit_btn.click(
        rag_pipeline,
        inputs=[query_input, file_output],
        outputs=answer_output,
    )
    gr.Markdown(
        """
        ## How it works
        1. Upload one or more `.txt` files
        2. Ask a question related to the content
        3. The system will:
           - Create embeddings using BERT
           - Find similar passages using vector similarity
           - Retrieve relevant context for your query
           - Generate an answer using `flan-t5-base`

        Built with 🤗 Hugging Face models and Gradio
        """
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()