Spaces:
Build error
Build error
| import gradio as gr | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| import faiss | |
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
# Load corpus
# NOTE(review): the original code assumed a "train" split with "title" and
# "text" columns. The dataset card appears to list a "passages" split with a
# "passage" column instead -- confirm against the hub. The loading below
# accepts either layout so the app starts in both cases.
print("Loading dataset...")
_corpus_splits = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus")
# Prefer "train" (the original behaviour); otherwise use the first split offered.
_corpus_split_name = "train" if "train" in _corpus_splits else next(iter(_corpus_splits))
_corpus_split = _corpus_splits[_corpus_split_name]
# Keep at most the first 1000 rows, mirroring the original "[:1000]" slice.
dataset = _corpus_split.select(range(min(1000, len(_corpus_split))))


def _format_passage(item):
    """Return the retrievable text for one dataset row.

    Rows with "title" and "text" keys are rendered as the title followed by
    the text on the next line (the original format); rows with a "passage"
    key are used verbatim.

    Raises:
        KeyError: If the row has neither set of keys.
    """
    if "title" in item and "text" in item:
        return f"{item['title']}\n{item['text']}"
    if "passage" in item:
        return item["passage"]
    raise KeyError(f"Unsupported corpus schema; available columns: {sorted(item)}")


corpus = [_format_passage(item) for item in dataset]
# Embedding model: encode every corpus entry once at start-up.
print("Encoding corpus...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
corpus_embeddings = embedder.encode(
    corpus,
    convert_to_tensor=True,
)
# FAISS is fed NumPy data, so move the tensor to the CPU and convert it.
corpus_embeddings_np = corpus_embeddings.cpu().numpy()

# FAISS index: flat L2 index sized to the embedding width (second axis).
_embedding_dim = corpus_embeddings_np.shape[1]
index = faiss.IndexFlatL2(_embedding_dim)
index.add(corpus_embeddings_np)
# Reranker model: cross-encoder that scores [query, document] pairs.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# Generator (choose one: local HF model or OpenAI) -- a local HF model is used here.
# The tokenizer and the model are loaded from the same checkpoint id.
_GENERATOR_CHECKPOINT = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(_GENERATOR_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(
    _GENERATOR_CHECKPOINT,
    device_map="auto",
    torch_dtype=torch.float16,
)
# Text-generation pipeline capped at 150 newly generated tokens per call.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=150,
)
def rag_pipeline(query):
    """Answer ``query`` with retrieval-augmented generation over ``corpus``.

    The query is embedded with ``embedder``, its nearest passages are fetched
    from the FAISS ``index``, the candidates are reordered by the
    cross-encoder ``reranker``, and the best passages are placed in a prompt
    for ``generator``.

    Args:
        query: The user's question as a string.

    Returns:
        The generated answer with surrounding whitespace removed, or a short
        notice when the query is blank or no passage could be retrieved.
    """
    top_k = 5  # candidates requested from FAISS
    context_size = 2  # reranked passages placed in the prompt

    # A blank question gives the retriever and generator nothing to work with.
    if not query or not query.strip():
        return "Please enter a question."

    # Embed query
    query_embedding = embedder.encode([query], convert_to_tensor=True).cpu().numpy()

    # Retrieve top-k from FAISS. Slots FAISS could not fill are reported as
    # -1, which would otherwise silently select the last corpus entry.
    _, neighbor_ids = index.search(query_embedding, k=top_k)
    retrieved_docs = [corpus[idx] for idx in neighbor_ids[0] if idx >= 0]
    if not retrieved_docs:
        return "No relevant documents found."

    # Rerank: order by score only, so tied scores never fall back to
    # comparing the document text.
    scores = reranker.predict([[query, doc] for doc in retrieved_docs])
    ranked = sorted(
        zip(scores, retrieved_docs),
        key=lambda pair: pair[0],
        reverse=True,
    )
    reranked_docs = [doc for _, doc in ranked]

    # Combine for context
    context = "\n\n".join(reranked_docs[:context_size])
    prompt = (
        "Answer the following question using the provided context.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n"
        "Answer:"
    )

    # Generate. Ask the pipeline for the continuation only: splitting the
    # full text on "Answer:" would cut the reply short whenever the model
    # itself emits that marker.
    response = generator(prompt, return_full_text=False)[0]["generated_text"]
    return response.strip()
# Gradio UI: a two-line question box wired to rag_pipeline, with a text output.
_question_box = gr.Textbox(lines=2, placeholder="Ask something...")
iface = gr.Interface(
    fn=rag_pipeline,
    inputs=_question_box,
    outputs="text",
    title="Mini RAG Wikipedia Demo",
    description="Retrieval-Augmented Generation on a small Wikipedia subset.",
)
# Start serving the interface as soon as the script is executed.
iface.launch()