Spaces:
Sleeping
Sleeping
File size: 6,386 Bytes
306e5e3 73f3dc0 306e5e3 4c8f98f 73f3dc0 306e5e3 73f3dc0 306e5e3 73f3dc0 306e5e3 73f3dc0 4c8f98f 73f3dc0 4c8f98f 73f3dc0 4c8f98f 73f3dc0 4c8f98f 73f3dc0 4c8f98f 73f3dc0 306e5e3 73f3dc0 306e5e3 73f3dc0 306e5e3 73f3dc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import CrossEncoder
import torch
import torch.nn.functional as F
from langchain.text_splitter import RecursiveCharacterTextSplitter
# --- Constants ---
# How many reranked results the UI displays.
TOP_K_FINAL = 3
# How many embedding-retrieved candidates are handed to the reranker.
RETRIEVAL_CANDIDATE_COUNT = 20

# --- 1. SETUP: Load all necessary models ---
# Hugging Face checkpoint ids for the two models used by the pipeline.
_QWEN_EMBEDDING_CHECKPOINT = "Qwen/Qwen3-Embedding-0.6B"
_RERANKER_CHECKPOINT = "cross-encoder/ms-marco-MiniLM-L-6-v2"

# First stage: bi-encoder used to embed chunks and queries for retrieval.
print("Loading Qwen3 Embedding Model (Retriever)...")
embedding_tokenizer = AutoTokenizer.from_pretrained(_QWEN_EMBEDDING_CHECKPOINT)
embedding_model = AutoModel.from_pretrained(_QWEN_EMBEDDING_CHECKPOINT)
print("Qwen3 Embedding Model loaded.")

# Second stage: cross-encoder that rescores (query, chunk) pairs.
print("Loading Reranker model (Cross-Encoder)...")
reranker_model = CrossEncoder(_RERANKER_CHECKPOINT)
print("Reranker model loaded.")
# --- 2. CORE FUNCTIONS ---
def _last_token_pool(last_hidden_state, attention_mask):
    """Return the hidden state of the last non-padding token of each row.

    Works for both left- and right-padded batches: for every row it selects
    the highest position whose attention-mask value is non-zero.

    Args:
        last_hidden_state: tensor of shape (batch, seq_len, hidden).
        attention_mask: tensor of shape (batch, seq_len); non-zero marks a
            real token, zero marks padding.

    Returns:
        Tensor of shape (batch, hidden).
    """
    positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
    # mask * position peaks at the last real token; argmax returns that index.
    last_positions = (attention_mask.long() * positions).argmax(dim=1)
    rows = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
    return last_hidden_state[rows, last_positions]


def get_qwen_embeddings_batch(texts):
    """Embed a batch of texts with the Qwen3 embedding model.

    Qwen3-Embedding is a causal (decoder-only) model, so the hidden state at
    position 0 has only attended to the first token and cannot represent the
    whole text. The model's documented pooling takes the hidden state of the
    last real token, which has attended to the full sequence. The result is
    then L2-normalised.

    Args:
        texts: list of strings to embed; must be non-empty.

    Returns:
        Float tensor with one unit-length row per input text.
    """
    # padding/truncation let texts of different lengths share one batch
    inputs = embedding_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    embeddings = _last_token_pool(outputs.last_hidden_state, inputs["attention_mask"])
    return F.normalize(embeddings, p=2, dim=1)
def process_and_index_document(source_text):
    """Chunk and embed a document; handler for the 'Index Document' button.

    Args:
        source_text: raw document text from the input textbox.

    Returns:
        A 4-tuple matching the click handler's outputs:
        (chunks, embeddings, status message, visibility update for the
        search UI). On invalid input the first two are None and the search
        UI is hidden.
    """
    if not source_text or not source_text.strip():
        # Update the UI to show an error and hide the search bar
        return None, None, "❌ Error: Please provide some source text.", gr.update(visible=False)

    print("--- Starting document processing ---")
    # a. Chunk the document (chunk sizes are measured in characters via len)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=50,
        length_function=len, separators=["\n\n", "\n", " ", ""],
    )
    chunks = text_splitter.split_text(source_text)
    print(f"Document split into {len(chunks)} chunks.")
    if not chunks:
        # Defensive: nothing to embed (torch.cat below would fail on an empty list).
        return None, None, "❌ Error: Please provide some source text.", gr.update(visible=False)

    # b. Vectorize the chunks with Qwen3 in fixed-size mini-batches. One
    # forward pass over every chunk makes peak memory grow with document
    # length; batching bounds it. Rows are concatenated back in chunk order.
    print("Vectorizing chunks with Qwen3... (This might take a moment)")
    batch_size = 32
    embeddings = torch.cat(
        [
            get_qwen_embeddings_batch(chunks[start:start + batch_size])
            for start in range(0, len(chunks), batch_size)
        ],
        dim=0,
    )
    print("Vectorization complete. Shape:", embeddings.shape)

    # c. Return the processed data and update UI
    success_message = f"✅ Document indexed successfully into {len(chunks)} chunks."
    # The last return value makes the search group visible
    return chunks, embeddings, success_message, gr.update(visible=True)
# Task description prepended to search queries (documents are embedded as-is),
# following the query format recommended on the Qwen3-Embedding model card.
QUERY_INSTRUCTION = "Given a web search query, retrieve relevant passages that answer the query"


def search_and_rerank(user_query, document_chunks, document_embeddings):
    """Retrieve candidate chunks by embedding similarity, then rerank them.

    Stage 1 embeds the query and keeps the RETRIEVAL_CANDIDATE_COUNT chunks
    with the highest cosine similarity. Stage 2 rescores those candidates
    with the cross-encoder and keeps the best TOP_K_FINAL.

    Args:
        user_query: text typed into the query box.
        document_chunks: list of chunk strings from the session state, or
            None when no document has been indexed.
        document_embeddings: tensor with one row per chunk (as produced at
            indexing time), or None when no document has been indexed.

    Returns:
        A flat list of 2 * TOP_K_FINAL strings alternating score label and
        chunk text, padded with "" when there are fewer results.
    """
    n_outputs = TOP_K_FINAL * 2
    if not user_query or not user_query.strip():
        return [""] * n_outputs
    if not document_chunks or document_embeddings is None:
        return ["Please index a document first."] * n_outputs

    # --- STAGE 1: RETRIEVAL ---
    # Qwen3-Embedding expects queries (not documents) to carry a task instruction.
    instructed_query = f"Instruct: {QUERY_INSTRUCTION}\nQuery:{user_query}"
    query_embedding = get_qwen_embeddings_batch([instructed_query])
    # A single query row broadcasts against every chunk row -> one score per chunk.
    similarities = F.cosine_similarity(query_embedding, document_embeddings)
    k = min(RETRIEVAL_CANDIDATE_COUNT, len(document_chunks))
    top_indices = torch.topk(similarities, k=k).indices.tolist()
    candidate_chunks = [document_chunks[i] for i in top_indices]

    # --- STAGE 2: RERANKING ---
    # The cross-encoder sees the raw user query, without the embedding instruction.
    rerank_scores = reranker_model.predict([[user_query, chunk] for chunk in candidate_chunks])
    reranked = sorted(zip(rerank_scores, candidate_chunks), key=lambda pair: pair[0], reverse=True)

    # --- Prepare final output ---
    outputs = []
    for score, chunk in reranked[:TOP_K_FINAL]:
        outputs.extend([f"Rerank Score: {score:.4f}", chunk])
    outputs.extend([""] * (n_outputs - len(outputs)))
    return outputs
# --- 3. GRADIO USER INTERFACE ---
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# 🧠 Dynamic RAG with Qwen3 + Reranker")
    # Blank line between the steps: a single "\n" does not break a Markdown
    # paragraph, so both steps would otherwise render on one line.
    gr.Markdown("**Step 1:** Paste your source text below and click 'Index Document'.\n\n"
                "**Step 2:** Once indexed, use the search bar to ask questions.")

    # gr.State holds per-session data: the chunk list and its embeddings.
    chunks_state = gr.State()
    embeddings_state = gr.State()

    with gr.Row():
        source_document_input = gr.Textbox(
            label="Source Document Text",
            placeholder="Paste the full text of your document here...",
            lines=15,
            scale=2
        )
        index_button = gr.Button("Index Document 🚀")
    status_display = gr.Markdown("Status: Ready to index a document.")

    # The search UI is hidden until indexing is complete
    with gr.Column(visible=False) as search_ui_group:
        gr.Markdown("---")
        gr.Markdown("### Step 2: Search Your Document")
        query_input = gr.Textbox(
            label="Your Question or Topic",
            placeholder="e.g., What is the main goal of the project?",
            lines=1
        )
        # One (score, chunk) pair of textboxes per displayed result, in the
        # same flat order that search_and_rerank returns.
        output_components = []
        for i in range(TOP_K_FINAL):
            with gr.Group():
                score = gr.Textbox(label=f"Result {i+1} Score", interactive=False)
                chunk_text = gr.Textbox(label="Retrieved Chunk", interactive=False, lines=4)
                output_components.extend([score, chunk_text])

    # --- Connect UI components to functions ---
    search_inputs = [query_input, chunks_state, embeddings_state]

    # Indexing fills the session state, updates the status and reveals the
    # search UI. Then the current query is re-run, so results shown from a
    # previously indexed document are replaced instead of going stale.
    index_button.click(
        fn=process_and_index_document,
        inputs=[source_document_input],
        outputs=[chunks_state, embeddings_state, status_display, search_ui_group]
    ).then(
        fn=search_and_rerank,
        inputs=search_inputs,
        outputs=output_components
    )

    # Live search: re-run retrieval + reranking whenever the query changes.
    query_input.change(
        fn=search_and_rerank,
        inputs=search_inputs,
        outputs=output_components
    )
if __name__ == "__main__":
    # Start the Gradio server; Gradio prints the local URL on launch.
    print("\nInterface is launching... Go to the printed URL.")
    iface.launch()