import os
import shutil
import uuid

import gradio as gr
import pandas as pd
import PyPDF2
import torch  # Ensure torch is imported
from datasets import load_dataset
from docx import Document
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer

# --- Configuration ---
QDRANT_PATH = "./qdrant_db"
COLLECTION_NAME = "my_text_collection"
# Better model for semantic similarity
MODEL_NAME = 'KaLM-Embedding/KaLM-embedding-multilingual-mini-instruct-v2.5'

# --- Load Model ---
device = "cpu"
model = SentenceTransformer(MODEL_NAME, device=device)

# --- Qdrant Client and Collection Setup ---
qdrant_client = QdrantClient(path=QDRANT_PATH)

# Check if the collection already exists
collection_exists = False
try:
    collection_info = qdrant_client.get_collection(collection_name=COLLECTION_NAME)
    print("Collection already exists.")
    collection_exists = True
except Exception as e:
    print(f"Collection not found: {e}, creating a new one...")
    collection_exists = False

# If collection doesn't exist, create it and populate with data
if not collection_exists:
    # Load dataset and convert to a simple list format
    dataset = load_dataset("ag_news", split="test")
    # Convert dataset to pandas dataframe to properly access the text column
    df = dataset.to_pandas()
    data = df['text'].tolist()[:1000]  # Get first 1000 text entries

    # Encode BEFORE creating the collection: if the download or the encoding
    # fails, no empty collection is left behind that the next run would
    # mistake for an already-populated one.
    print("Generating and indexing embeddings...")
    embeddings = model.encode(data)

    # Prefer the dimension reported by the model; otherwise fall back to the
    # width of the vectors that are actually about to be stored.
    vector_size = model.get_sentence_embedding_dimension() or int(embeddings.shape[1])
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.COSINE),
    )

    # Prepare points for insertion (seed data keeps integer ids 0..N-1)
    points = [
        models.PointStruct(
            id=i,
            vector=embedding.tolist(),
            payload={"document": text},
        )
        for i, (text, embedding) in enumerate(zip(data, embeddings))
    ]

    # Upload points to the collection
    qdrant_client.upsert(
        collection_name=COLLECTION_NAME,
        points=points,
    )
    print("Embeddings indexed successfully.")


# --- Search Function ---
def search_in_qdrant(query):
    """Return the five nearest stored documents for *query* as Markdown.

    Args:
        query: Free-text search string typed by the user.

    Returns:
        A Markdown string with one "Score / Text" section per hit, or a
        short message when the query is blank or nothing matched.
    """
    # A whitespace-only query carries no meaning; treat it like an empty one.
    if not query or not query.strip():
        return "Please enter a search query."

    # Generate embedding for the query
    query_embedding = model.encode([query])[0].tolist()

    # NOTE(review): newer qdrant-client releases may prefer `query_points`
    # over `search` -- confirm against the installed client version.
    hits = qdrant_client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_embedding,
        limit=5,
    )
    if not hits:
        return "No results found."

    sections = []
    for hit in hits:
        # Points without a "document" payload still get a placeholder entry.
        if hit.payload and 'document' in hit.payload:
            body = hit.payload['document']
        else:
            body = "[No document content available]"
        sections.append(f"**Score:** {hit.score:.4f}\n**Text:** {body}\n\n")
    return "".join(sections)


# --- Upload Function ---
def _read_text_file(file_path):
    """Read *file_path* as UTF-8 text, falling back to latin-1."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    except UnicodeDecodeError:
        # latin-1 maps every byte to a character, so this decode cannot fail.
        with open(file_path, 'r', encoding='latin-1') as f:
            return f.read()


def extract_text_from_file(file_path):
    """Extract text from various file types.

    The reader is chosen from the file extension: ``pdf`` (PyPDF2),
    ``docx``/``doc`` (python-docx), ``csv``/``xlsx``/``xls`` (pandas, rendered
    with ``DataFrame.to_string``). Anything else, including ``txt``, is read
    as plain text.

    Args:
        file_path: Path of the file to read.

    Returns:
        The extracted text.

    Raises:
        OSError: If the file cannot be opened.
        Exception: Whatever the underlying parser raises for a malformed
            file. Failures are raised rather than returned as text so the
            caller never indexes an error message as document content.
    """
    file_extension = os.path.splitext(file_path)[1].lower().lstrip('.')

    if file_extension == 'pdf':
        with open(file_path, 'rb') as f:
            pdf_reader = PyPDF2.PdfReader(f)
            # Guard against pages for which no text could be extracted.
            return "".join(
                (page.extract_text() or "") + "\n" for page in pdf_reader.pages
            )

    if file_extension in ['docx', 'doc']:
        # NOTE(review): python-docx presumably only parses .docx packages;
        # a legacy binary .doc is expected to raise here -- confirm.
        doc = Document(file_path)
        return "".join(paragraph.text + "\n" for paragraph in doc.paragraphs)

    if file_extension in ['csv', 'xlsx', 'xls']:
        if file_extension == 'csv':
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)
        # Convert the entire dataframe to text
        return df.to_string()

    # 'txt' and every unrecognised extension: read as plain text.
    return _read_text_file(file_path)


def upload_to_qdrant(text_content, file_upload=None):
    """Embed the given text and/or uploaded file and add them to the collection.

    Args:
        text_content: Text typed by the user; may be empty.
        file_upload: Uploaded file, either a path string or an object whose
            ``name`` attribute is the path; may be ``None``.

    Returns:
        A human-readable status message describing the outcome.
    """
    if not text_content and not file_upload:
        return "Please provide text content or upload a file."

    documents_to_add = []

    # Add text content if provided (blank text is not worth indexing)
    if text_content and text_content.strip():
        documents_to_add.append(text_content)

    # Process uploaded file if provided
    if file_upload:
        # Accept both a plain path string and a file-like object with `.name`.
        file_path = getattr(file_upload, "name", file_upload)
        try:
            content = extract_text_from_file(file_path)
        except Exception as e:
            return f"Error reading file: {str(e)}"
        # Skip files that yielded no text at all (e.g. image-only PDFs).
        if content and content.strip():
            documents_to_add.append(content)

    if not documents_to_add:
        return "No content to upload."

    # Generate embeddings for the new documents
    embeddings = model.encode(documents_to_add)

    # Use random UUIDs as point ids. Ids derived from the point count can
    # collide with (and silently overwrite) existing points when the count
    # is unavailable or points have been deleted.
    points = [
        models.PointStruct(
            id=str(uuid.uuid4()),
            vector=embedding.tolist(),
            payload={"document": doc},
        )
        for doc, embedding in zip(documents_to_add, embeddings)
    ]

    # Upload points to the collection
    qdrant_client.upsert(
        collection_name=COLLECTION_NAME,
        points=points,
    )

    return f"Successfully added {len(documents_to_add)} document(s) to the collection."
# --- Gradio Interface ---
# Two-tab UI: "Search" queries the collection, "Upload" adds new documents.
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search with Qdrant and Gradio")
    gr.Markdown("Enter a query to search for similar news articles from the AG News dataset.")

    with gr.Tab("Search"):
        # Query box and trigger button share a row; results render below.
        with gr.Row():
            search_input = gr.Textbox(
                label="Search Query",
                placeholder="e.g., 'Latest news on space exploration'",
            )
            search_button = gr.Button("Search")
        search_output = gr.Markdown()
        search_button.click(
            fn=search_in_qdrant,
            inputs=search_input,
            outputs=search_output,
        )

    with gr.Tab("Upload"):
        # Free-text entry and file upload are alternative (or combined) inputs.
        with gr.Row():
            text_input = gr.Textbox(
                label="Text Content",
                placeholder="Enter text to add to the collection",
                lines=5,
            )
        with gr.Row():
            file_input = gr.File(
                label="Or Upload a File",
                file_types=['.txt', '.pdf', '.docx', '.csv', '.xlsx', '.xls', '.md'],
            )
        upload_button = gr.Button("Upload to Collection")
        upload_output = gr.Textbox(label="Upload Status", interactive=False)
        upload_button.click(
            fn=upload_to_qdrant,
            inputs=[text_input, file_input],
            outputs=upload_output,
        )

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()