Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import vecmini | |
| import pypdf | |
| from transformers import AutoTokenizer, AutoModel | |
# --- Runtime configuration and embedding model (loaded once, at import) ---
# Prefer the GPU when torch can see one; the model is moved onto this device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
torch_dtype = torch.float32

print("Loading Sentence Encoder (Bi-Encoder Only)...")
embed_id = "sentence-transformers/all-MiniLM-L6-v2"
embed_tokenizer = AutoTokenizer.from_pretrained(embed_id)
embed_model = AutoModel.from_pretrained(embed_id)
embed_model = embed_model.to(device=device, dtype=torch_dtype)

# Mutable state shared by the Gradio callbacks:
#   global_chunks - chunk texts; list position is the id stored in `db`
#   db            - vecmini index (None until a PDF has been processed)
#   global_nlist  - number of IVF lists the current index was built with
global_chunks, db, global_nlist = [], None, 1
def mean_pooling(model_output, attention_mask):
    """Average token embeddings, counting only positions where the mask is set.

    Args:
        model_output: Model output whose first element (``model_output[0]``)
            holds the token embeddings.
        attention_mask: Mask with one fewer trailing axis than the embeddings;
            it is broadcast to the embeddings' shape before weighting.

    Returns:
        Tensor of masked means, reduced over dimension 1 (the token axis for
        a ``(batch, seq, dim)`` input).
    """
    hidden = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * weights).sum(dim=1)
    # Clamp so a fully-masked row divides by a tiny number instead of zero.
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
def encode_texts(texts, batch_size=64):
    """Embed texts with the sentence encoder and L2-normalise each vector.

    Args:
        texts: List of strings to embed. A single string is treated as a
            one-element list.
        batch_size: Maximum number of texts sent through the model in one
            forward pass. Encoding a whole document at once can exhaust
            memory for long PDFs, so the input is processed in slices.

    Returns:
        ``np.ndarray`` of dtype float32 with one unit-length row per input
        text, in input order. Empty input yields an array with zero rows.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    if isinstance(texts, str):
        # Slicing a bare string below would split it into characters.
        texts = [texts]
    if len(texts) == 0:
        return np.zeros((0, embed_model.config.hidden_size), dtype=np.float32)

    parts = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = list(texts[start:start + batch_size])
            encoded_input = embed_tokenizer(
                batch, padding=True, truncation=True, return_tensors="pt"
            ).to(device)
            model_output = embed_model(**encoded_input)
            # Masked mean pooling ignores padding, so per-batch padding
            # lengths do not change the resulting vectors.
            embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
            normalized = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            parts.append(normalized.cpu().numpy().astype(np.float32))
    return np.concatenate(parts, axis=0)
def _read_pdf_text(file_obj):
    """Return the text of every PDF page that yields any, joined by spaces.

    ``file_obj`` may be a path (``str`` / ``os.PathLike``) or an object whose
    ``.name`` attribute holds the path (e.g. a tempfile wrapper).
    """
    import os  # local import: the file's top-level import block is untouched

    if isinstance(file_obj, (str, os.PathLike)):
        path = file_obj
    else:
        path = file_obj.name
    reader = pypdf.PdfReader(path)
    # Pages whose extraction yields nothing (empty string / None) are skipped.
    page_texts = (page.extract_text() for page in reader.pages)
    return " ".join(t for t in page_texts if t)


def _chunk_words(text, chunk_size=200):
    """Split ``text`` on whitespace and regroup it into ``chunk_size``-word chunks."""
    words = text.split()
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]


def process_pdf(file_obj):
    """Gradio callback: extract a PDF's text, chunk it, embed it and index it.

    The module-level ``global_chunks``, ``db`` and ``global_nlist`` are
    replaced together, and only after the index has been built successfully.
    On any failure they keep their previous values, so an existing index
    never ends up paired with another document's chunks.

    Args:
        file_obj: Value of the ``gr.File`` component. ``None`` when nothing
            is uploaded; otherwise a path or an object exposing ``.name``.

    Returns:
        A human-readable status message for the "Index Status" textbox.
    """
    global global_chunks, db, global_nlist
    if file_obj is None:
        return "Error: No file uploaded."
    try:
        text = _read_pdf_text(file_obj)
    except Exception as e:
        return f"Failed to read PDF: {str(e)}"
    if not text.strip():
        return "Error: Could not extract any readable text from this PDF."

    chunks = _chunk_words(text)
    try:
        embeddings = encode_texts(chunks)
        nb, d = embeddings.shape
        # About four vectors per inverted list on average, never fewer than one list.
        nlist = max(1, nb // 4)
        index = vecmini.IndexIVF(d, nlist)
        index.train(nb, embeddings)
        # Vector ids are chunk positions, so search labels index `chunks`.
        index.add(nb, embeddings, np.arange(nb, dtype=np.uint64))
    except Exception as e:
        return f"Failed to build vecmini index: {str(e)}"

    # Commit chunks and index together so retrieval always sees a matching pair.
    global_chunks, db, global_nlist = chunks, index, nlist
    return f"Success! Extracted {nb} chunks from the PDF and built vecmini index."
def retrieve_chunks(query, top_k):
    """Gradio callback: search the vecmini index and render the hits as Markdown.

    Args:
        query: Free-text search query from the textbox.
        top_k: Requested number of results (slider value; converted with ``int``).

    Returns:
        Markdown listing each retrieved chunk with its distance and chunk id,
        or a short instruction message when there is no index or no query.
    """
    # Snapshot the shared state once so the whole request uses one index/chunk pair.
    chunks, index, nlist = global_chunks, db, global_nlist
    if index is None or not chunks:
        return "Please upload and process a PDF first."
    if not query or not query.strip():
        return "Please enter a search query."

    query_emb = encode_texts([query])
    fetch_k = min(int(top_k), len(chunks))
    nprobe = max(1, int(nlist / 2))
    distances, labels = index.search(1, query_emb, k=fetch_k, nprobe=nprobe, bitmask=None)

    # Keep each label paired with its own distance, and drop ids that do not
    # name a stored chunk. A negative id would otherwise silently index from
    # the end of the list. NOTE(review): presumably vecmini pads with a
    # sentinel when fewer than k hits are found — confirm its convention.
    n_chunks = len(chunks)
    hits = [
        (int(idx), dist)
        for idx, dist in zip(labels[0], distances[0])
        if 0 <= int(idx) < n_chunks
    ]

    parts = [f"### Top {len(hits)} Results for: *'{query}'*\n\n"]
    for rank, (idx, dist) in enumerate(hits, start=1):
        parts.append(f"**Result {rank}** | Vector Distance: `{dist:.4f}` | Chunk ID: `{idx}`\n")
        parts.append(f"> {chunks[idx]}\n\n---\n\n")
    return "".join(parts)
# --- Gradio UI: the left column builds the index, the right column queries it ---
with gr.Blocks(title="Vecmini Visualizer") as demo:
    gr.Markdown("# Vecmini: PDF Raw Retrieval Tester")
    gr.Markdown("Upload a PDF, build the index, and see exactly what `vecmini` returns for your queries.")

    with gr.Row():
        with gr.Column():
            pdf_input = gr.File(label="Upload PDF Document", file_types=[".pdf"])
            process_btn = gr.Button("Build Vecmini Index", variant="primary")
            status_out = gr.Textbox(label="Index Status", interactive=False)
        with gr.Column():
            query_input = gr.Textbox(label="Search Query")
            k_slider = gr.Slider(
                minimum=1,
                maximum=20,
                value=5,
                step=1,
                label="Number of chunks to retrieve (K)",
            )
            search_btn = gr.Button("Search Vecmini")
            results_out = gr.Markdown(label="Retrieved Chunks")

    process_btn.click(fn=process_pdf, inputs=pdf_input, outputs=status_out)
    # Clicking the button and pressing Enter in the query box run the same search.
    for search_trigger in (search_btn.click, query_input.submit):
        search_trigger(fn=retrieve_chunks, inputs=[query_input, k_slider], outputs=results_out)

demo.launch(server_name="0.0.0.0", server_port=7860)