import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import os
import spaces

# Log in automatically if HF_TOKEN is present
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    from huggingface_hub import login
    login(token=hf_token)
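# Note: the Mistral checkpoint loaded below is gated on the Hugging Face Hub,
# so a valid HF_TOKEN (for an account that accepted the license) is typically
# required for the download to succeed.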

# Load corpus
print("Loading dataset...")
dataset = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus")

# Always clean + use this corpus consistently.
# Each item in the "passages" split is a dict; keep only its "passage" text
# (stringifying the whole dict would embed the metadata too).
corpus = []
for item in dataset["passages"]:
    text = str(item["passage"]).strip()
    if text:
        corpus.append(text)
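
# The mini-wikipedia corpus is small (a few thousand short passages), so the
# exact in-memory search built below stays cheap.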

# Embedding model
print("Encoding corpus...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True, device="cpu")
corpus_embeddings_np = corpus_embeddings.numpy()

# FAISS index
index = faiss.IndexFlatL2(corpus_embeddings_np.shape[1])
index.add(corpus_embeddings_np)
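# IndexFlatL2 performs exhaustive (exact) L2 search, which is fine at this
# scale. A possible alternative, if cosine similarity is preferred (a sketch
# over the same embeddings; normalize_L2 modifies the array in place):
# faiss.normalize_L2(corpus_embeddings_np)
# index = faiss.IndexFlatIP(corpus_embeddings_np.shape[1])
# index.add(corpus_embeddings_np)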

# Reranker model
# reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# Generator (choose one: local HF model or OpenAI)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",
    torch_dtype=torch.float16,
    device_map="auto",  # requires accelerate; places the model on GPU when one is available
)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150)
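# Rough memory note: a 7B-parameter model in float16 needs on the order of
# 14-15 GB just for weights, so this Space realistically needs GPU hardware.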

@spaces.GPU  # request GPU for this call (assumes a ZeroGPU Space; remove otherwise)
def rag_pipeline(query):
    # Embed query
    query_embedding = embedder.encode([query], convert_to_tensor=True, device="cpu").numpy()

    # Retrieve top-k from FAISS (D: squared L2 distances, I: corpus indices)
    D, I = index.search(query_embedding, k=5)
    retrieved_docs = [corpus[idx] for idx in I[0]]
    print("Retrieved indices:", I[0])
    print("Retrieved docs:")
    for doc in retrieved_docs:
        print("-", repr(doc))

    # # Rerank
    # rerank_pairs = [[str(query), str(doc)] for doc in retrieved_docs]
    # scores = reranker.predict(rerank_pairs)
    # reranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), reverse=True)]
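    # If the reranker above is re-enabled, build the context below from
    # reranked_docs instead of retrieved_docs.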

    # Combine top passages for context
    context = "\n\n".join(retrieved_docs[:2])
    prompt = f"""Answer the following question using the provided context.

Context:
{context}

Question: {query}
Answer:"""

    # Generate; generated_text echoes the prompt, so keep only what follows "Answer:"
    response = generator(prompt)[0]["generated_text"]
    return response.split("Answer:")[-1].strip()
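
# Quick sanity check for local runs (hypothetical example query):
# print(rag_pipeline("Who was the first person to walk on the Moon?"))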

# Gradio UI
iface = gr.Interface(
    fn=rag_pipeline,
    inputs=gr.Textbox(lines=2, placeholder="How fast is a penguin?"),
    outputs="text",
    title="Mini RAG Wikipedia Demo",
    description="Retrieval-Augmented Generation on a small Wikipedia subset.",
)
iface.launch()