Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModel | |
| from sentence_transformers import CrossEncoder | |
| import torch | |
| import torch.nn.functional as F | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
# --- Constants ---
# Number of reranked results shown to the user (one score box + one chunk box each).
TOP_K_FINAL = 3
# Number of embedding-similarity candidates handed to the cross-encoder reranker.
RETRIEVAL_CANDIDATE_COUNT = 20

# --- 1. SETUP: Load all necessary models ---
# Models are loaded once when this module is imported and reused by every call below.
EMBEDDING_MODEL_ID = "Qwen/Qwen3-Embedding-0.6B"
RERANKER_MODEL_ID = "cross-encoder/ms-marco-MiniLM-L-6-v2"

print("Loading Qwen3 Embedding Model (Retriever)...")
embedding_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_ID)
embedding_model = AutoModel.from_pretrained(EMBEDDING_MODEL_ID)
print("Qwen3 Embedding Model loaded.")

print("Loading Reranker model (Cross-Encoder)...")
reranker_model = CrossEncoder(RERANKER_MODEL_ID)
print("Reranker model loaded.")

# --- 2. CORE FUNCTIONS ---
def last_token_pool(last_hidden_states, attention_mask):
    """Select each sequence's hidden state at its last non-padding token.

    Per the Qwen3-Embedding model card, the model is a causal (decoder-only)
    transformer pooled at the last token: only the final real token has
    attended to the whole input. Works for both left- and right-padded batches.

    Args:
        last_hidden_states: Tensor of shape (batch, seq_len, hidden).
        attention_mask: Tensor of shape (batch, seq_len); nonzero marks real
            tokens and 0 marks padding.

    Returns:
        Tensor of shape (batch, hidden).
    """
    # If every row has a real token in the final column, the batch is
    # left-padded (or unpadded) and that column is the answer for all rows.
    if bool((attention_mask[:, -1] != 0).all()):
        return last_hidden_states[:, -1]
    # Right padding: real tokens occupy positions 0..len-1 of each row, so
    # index each row at its own last real position.
    last_positions = attention_mask.sum(dim=1) - 1
    rows = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[rows, last_positions]


def get_qwen_embeddings_batch(texts):
    """Embed a batch of texts with Qwen3-Embedding in a single forward pass.

    Args:
        texts: List of strings. The batch is padded to its longest item and
            truncated to the tokenizer's maximum length.

    Returns:
        Tensor of shape (len(texts), hidden); each row is L2-normalised.
    """
    inputs = embedding_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    # Taking `[:, 0, :]` (BERT-style CLS pooling) is wrong for this model: in a
    # causal model position 0 has only seen the first token (or is padding
    # under left padding), so texts that start alike would collapse to the
    # same vector. Pool at the last real token instead.
    embeddings = last_token_pool(outputs.last_hidden_state, inputs["attention_mask"])
    # Unit-length rows, as in the model card; cosine similarity is unaffected.
    return F.normalize(embeddings, p=2, dim=1)
def process_and_index_document(source_text):
    """Chunk and embed a document so it can be searched.

    Triggered by the 'Index Document' button. Splits *source_text* into
    overlapping character chunks, embeds them with Qwen3 in mini-batches, and
    returns the results for the session state.

    Args:
        source_text: Raw text pasted by the user.

    Returns:
        A 4-tuple wired to ``[chunks_state, embeddings_state, status_display,
        search_ui_group]``: the chunk strings, their stacked embeddings (one
        row per chunk), a status message, and a ``gr.update`` that shows the
        search UI on success or hides it on error. On error the first two
        values are None.
    """
    error_message = "❌ Error: Please provide some source text."
    if not source_text or not source_text.strip():
        # Update the UI to show an error and hide the search bar.
        return None, None, error_message, gr.update(visible=False)

    print("--- Starting document processing ---")

    # a. Chunk the document: ~500 characters per chunk with 50 characters of overlap.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )
    chunks = text_splitter.split_text(source_text)
    print(f"Document split into {len(chunks)} chunks.")
    if not chunks:
        # Defensive: torch.cat below raises on an empty list.
        return None, None, error_message, gr.update(visible=False)

    # b. Vectorize the chunks with Qwen3 in bounded mini-batches. A single
    # forward pass over every chunk pads them all into one tensor, so peak
    # memory would grow with the length of the document.
    print("Vectorizing chunks with Qwen3... (This might take a moment)")
    batch_size = 32
    embeddings = torch.cat(
        [
            get_qwen_embeddings_batch(chunks[start:start + batch_size])
            for start in range(0, len(chunks), batch_size)
        ],
        dim=0,
    )
    print("Vectorization complete. Shape:", embeddings.shape)

    # c. Return the processed data; the last value makes the search group visible.
    success_message = f"✅ Document indexed successfully into {len(chunks)} chunks."
    return chunks, embeddings, success_message, gr.update(visible=True)
def search_and_rerank(user_query, document_chunks, document_embeddings):
    """Answer *user_query* against the indexed document in two stages.

    Stage 1 ranks every chunk by cosine similarity between its embedding and
    the query embedding and keeps the best ``RETRIEVAL_CANDIDATE_COUNT``.
    Stage 2 re-scores that shortlist with the cross-encoder and keeps the best
    ``TOP_K_FINAL``.

    Args:
        user_query: Text typed into the search box.
        document_chunks: Chunk strings from session state (None before indexing).
        document_embeddings: Chunk embeddings from session state, row-aligned
            with *document_chunks*.

    Returns:
        A flat list of ``2 * TOP_K_FINAL`` strings alternating a score label
        and its chunk text, padded with empty strings when fewer results exist.
    """
    slot_count = TOP_K_FINAL * 2
    if not user_query or not user_query.strip():
        return [""] * slot_count
    if document_chunks is None:
        return ["Please index a document first."] * slot_count

    # Stage 1: dense retrieval over all chunk embeddings.
    query_vec = get_qwen_embeddings_batch([user_query])
    cosine_scores = F.cosine_similarity(query_vec, document_embeddings)
    shortlist_size = min(RETRIEVAL_CANDIDATE_COUNT, len(document_chunks))
    best_indices = torch.topk(cosine_scores, k=shortlist_size).indices
    shortlist = [document_chunks[i] for i in best_indices]

    # Stage 2: cross-encoder reranking of the shortlist.
    pair_scores = reranker_model.predict([[user_query, passage] for passage in shortlist])
    ranked = sorted(zip(pair_scores, shortlist), key=lambda pair: pair[0], reverse=True)

    # Flatten into (score label, chunk) slots for the output textboxes.
    flat = []
    for rerank_score, passage in ranked[:TOP_K_FINAL]:
        flat += [f"Rerank Score: {rerank_score:.4f}", passage]
    flat += [""] * (slot_count - len(flat))
    return flat
# --- 3. GRADIO USER INTERFACE ---
# NOTE(review): the original indentation of this section was lost; the nesting
# of index_button / status_display relative to the Row is inferred — confirm.
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# 🧠 Dynamic RAG with Qwen3 + Reranker")
    # A blank line ("\n\n") separates the two steps: Markdown renders a single
    # "\n" inside a paragraph as a space, which ran both steps together.
    gr.Markdown("**Step 1:** Paste your source text below and click 'Index Document'.\n\n"
                "**Step 2:** Once indexed, use the search bar to ask questions.")

    # gr.State holds session-specific data (chunks and embeddings).
    chunks_state = gr.State()
    embeddings_state = gr.State()

    with gr.Row():
        source_document_input = gr.Textbox(
            label="Source Document Text",
            placeholder="Paste the full text of your document here...",
            lines=15,
            scale=2,
        )
        index_button = gr.Button("Index Document 🚀")
    status_display = gr.Markdown("Status: Ready to index a document.")

    # The search UI stays hidden until process_and_index_document reveals it.
    with gr.Column(visible=False) as search_ui_group:
        gr.Markdown("---")
        gr.Markdown("### Step 2: Search Your Document")
        query_input = gr.Textbox(
            label="Your Question or Topic",
            placeholder="e.g., What is the main goal of the project?",
            lines=1,
        )
        # One (score, chunk) textbox pair per final result, in the same order
        # as the flat list returned by search_and_rerank.
        output_components = []
        for i in range(TOP_K_FINAL):
            with gr.Group():
                score_box = gr.Textbox(label=f"Result {i+1} Score", interactive=False)
                chunk_box = gr.Textbox(label="Retrieved Chunk", interactive=False, lines=4)
            output_components.extend([score_box, chunk_box])

    # --- Connect UI components to functions ---
    # Indexing fills the session state, reports status, and toggles the search group.
    index_button.click(
        fn=process_and_index_document,
        inputs=[source_document_input],
        outputs=[chunks_state, embeddings_state, status_display, search_ui_group],
    )
    # Live search: fires on every change to the query box, so each edit runs
    # one embedding pass plus a rerank over the shortlist.
    query_input.change(
        fn=search_and_rerank,
        # The state components must be listed as inputs to reach the function.
        inputs=[query_input, chunks_state, embeddings_state],
        outputs=output_components,
    )

if __name__ == "__main__":
    print("\nInterface is launching... Go to the printed URL.")
    iface.launch()