File size: 6,386 Bytes
306e5e3
 
73f3dc0
306e5e3
4c8f98f
73f3dc0
306e5e3
73f3dc0
 
 
306e5e3
73f3dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306e5e3
73f3dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c8f98f
73f3dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c8f98f
73f3dc0
 
 
 
 
 
 
4c8f98f
73f3dc0
 
 
 
4c8f98f
73f3dc0
4c8f98f
73f3dc0
 
 
 
306e5e3
73f3dc0
 
 
306e5e3
73f3dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306e5e3
 
73f3dc0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import CrossEncoder
import torch
import torch.nn.functional as F
from langchain.text_splitter import RecursiveCharacterTextSplitter

# --- Constants ---
TOP_K_FINAL = 3               # number of reranked passages shown in the UI
RETRIEVAL_CANDIDATE_COUNT = 20  # candidates pulled by the retriever before reranking

# --- 1. SETUP: Load all necessary models ---
# NOTE: both models are downloaded/loaded at import time (module-level side
# effect), so the first start of this script is slow and requires network
# access to the Hugging Face hub.

print("Loading Qwen3 Embedding Model (Retriever)...")
# Using the model you specified
embedding_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
embedding_model = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B")
print("Qwen3 Embedding Model loaded.")

print("Loading Reranker model (Cross-Encoder)...")
# Cross-encoder scores (query, chunk) pairs jointly — slower than the
# bi-encoder above, hence used only on the retrieved candidate set.
reranker_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
print("Reranker model loaded.")


# --- 2. CORE FUNCTIONS ---

def get_qwen_embeddings_batch(texts):
    """Embed a batch of texts with the Qwen3 embedding model.

    Args:
        texts: list[str] to embed together (batching is much more efficient
            than one-by-one calls).

    Returns:
        torch.Tensor of shape (len(texts), hidden_dim) with L2-normalized
        embeddings.
    """
    # `padding=True` and `truncation=True` are key for batching
    inputs = embedding_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = embedding_model(**inputs)
        hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_dim)
        attention_mask = inputs["attention_mask"]
        # BUGFIX: Qwen3-Embedding is a decoder-only (causal) model — it has
        # no [CLS] token, so `hidden[:, 0, :]` only captured the first
        # token's context. The model card's recommended pooling is the LAST
        # non-padding token's hidden state ("last_token_pool").
        if bool(attention_mask[:, -1].all()):
            # No right-padding (e.g. left-padded or equal-length batch):
            # the final position is valid for every sequence.
            embeddings = hidden[:, -1, :]
        else:
            # Right-padded batch: pick each sequence's last real token.
            last_token_idx = attention_mask.sum(dim=1) - 1
            batch_idx = torch.arange(hidden.size(0), device=hidden.device)
            embeddings = hidden[batch_idx, last_token_idx]
        # L2-normalize as the model card recommends. Downstream cosine
        # similarity is scale-invariant, so this is behavior-safe there.
        embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings

def process_and_index_document(source_text):
    """Handler for the 'Index Document' button.

    Splits the pasted text into overlapping chunks, embeds every chunk with
    Qwen3, and returns the data for Gradio's session state.

    Args:
        source_text: raw document text from the input textbox.

    Returns:
        Tuple of (chunks, embeddings, status_markdown, search_ui_visibility).
        On empty input the first two entries are None and the search UI
        stays hidden.
    """
    if not source_text or not source_text.strip():
        # Nothing to index — report the error and keep the search bar hidden.
        return None, None, "❌ Error: Please provide some source text.", gr.update(visible=False)

    print("--- Starting document processing ---")

    # a. Chunk the document
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )
    doc_chunks = splitter.split_text(source_text)
    print(f"Document split into {len(doc_chunks)} chunks.")

    # b. Vectorize the chunks using Qwen3
    print("Vectorizing chunks with Qwen3... (This might take a moment)")
    chunk_vectors = get_qwen_embeddings_batch(doc_chunks)
    print("Vectorization complete. Shape:", chunk_vectors.shape)

    # c. Hand back state + status; the final value reveals the search group.
    status = f"✅ Document indexed successfully into {len(doc_chunks)} chunks."
    return doc_chunks, chunk_vectors, status, gr.update(visible=True)


def search_and_rerank(user_query, document_chunks, document_embeddings):
    """Two-stage search: dense retrieval, then cross-encoder reranking.

    Args:
        user_query: the user's question (live-updated textbox value).
        document_chunks: list of chunk strings from session state, or None
            if no document has been indexed yet.
        document_embeddings: tensor of chunk embeddings from session state.

    Returns:
        Flat list of 2 * TOP_K_FINAL strings: alternating score label and
        chunk text for each output textbox, padded with "" when fewer than
        TOP_K_FINAL results exist.
    """
    slot_count = TOP_K_FINAL * 2

    if not user_query or not user_query.strip():
        return [""] * slot_count

    if document_chunks is None:
        return ["Please index a document first."] * slot_count

    # --- STAGE 1: RETRIEVAL ---
    # Embed the single query; (1, dim) broadcasts against (n_chunks, dim).
    query_vec = get_qwen_embeddings_batch([user_query])
    scores = F.cosine_similarity(query_vec, document_embeddings)

    # Pull the top candidates (never more than we have chunks).
    k = min(RETRIEVAL_CANDIDATE_COUNT, len(document_chunks))
    candidate_ids = torch.topk(scores, k=k).indices
    candidates = [document_chunks[i] for i in candidate_ids]

    # --- STAGE 2: RERANKING ---
    pairs = [[user_query, chunk] for chunk in candidates]
    ce_scores = reranker_model.predict(pairs)

    ranked = sorted(zip(ce_scores, candidates), key=lambda pair: pair[0], reverse=True)

    # --- Prepare final output ---
    results = []
    for ce_score, chunk in ranked[:TOP_K_FINAL]:
        results.append(f"Rerank Score: {ce_score:.4f}")
        results.append(chunk)

    # Pad so every output textbox always receives a value.
    while len(results) < slot_count:
        results += ["", ""]

    return results

# --- 3. GRADIO USER INTERFACE ---
# NOTE: statement order inside this `with` block IS the rendered layout,
# so components are created top-to-bottom exactly as they should appear.

with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# 🧠 Dynamic RAG with Qwen3 + Reranker")
    gr.Markdown("**Step 1:** Paste your source text below and click 'Index Document'.\n"
                "**Step 2:** Once indexed, use the search bar to ask questions.")

    # We use gr.State to hold session-specific data (chunks and embeddings),
    # so concurrent users each get their own index.
    chunks_state = gr.State()
    embeddings_state = gr.State()

    with gr.Row():
        source_document_input = gr.Textbox(
            label="Source Document Text",
            placeholder="Paste the full text of your document here...",
            lines=15,
            scale=2
        )
    
    index_button = gr.Button("Index Document 🚀")
    status_display = gr.Markdown("Status: Ready to index a document.")
    
    # The search UI is hidden until indexing is complete; the indexing
    # handler's final return value flips this group visible.
    with gr.Column(visible=False) as search_ui_group:
        gr.Markdown("---")
        gr.Markdown("### Step 2: Search Your Document")
        query_input = gr.Textbox(
            label="Your Question or Topic",
            placeholder="e.g., What is the main goal of the project?",
            lines=1
        )

        # One (score, chunk) textbox pair per result slot; the flat list
        # order must match the flat list returned by search_and_rerank.
        output_components = []
        for i in range(TOP_K_FINAL):
            with gr.Group():
                score = gr.Textbox(label=f"Result {i+1} Score", interactive=False)
                chunk_text = gr.Textbox(label="Retrieved Chunk", interactive=False, lines=4)
            output_components.extend([score, chunk_text])
    
    # --- Connect UI components to functions ---
    
    # When the index button is clicked...
    index_button.click(
        fn=process_and_index_document,
        inputs=[source_document_input],
        # The outputs are the state variables, the status message, and the search UI group
        outputs=[chunks_state, embeddings_state, status_display, search_ui_group]
    )
    
    # When the query input changes (live search)...
    # NOTE: `change` fires on every keystroke, so each keystroke triggers a
    # full embed + rerank round trip.
    query_input.change(
        fn=search_and_rerank,
        # Inputs must include the state variables
        inputs=[query_input, chunks_state, embeddings_state],
        outputs=output_components
    )

# Script entry point: launch the Gradio server (blocks until shut down).
if __name__ == "__main__":
    print("\nInterface is launching... Go to the printed URL.")
    iface.launch()