Spaces:
Sleeping
Sleeping
| # Step 1: Import required packages | |
| import configparser | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import torch | |
| import numpy as np | |
| import os | |
| import pandas as pd | |
class TextSearchEngine:
    """Semantic text search over sentence embeddings cached in a CSV file.

    Texts are encoded with a sentence-embedding model, stored in the CSV as
    comma-separated strings, and ranked against a query by cosine similarity.
    """

    def __init__(self, embeddings_csv_path):
        """Remember where the embeddings CSV lives.

        Args:
            embeddings_csv_path: Path of the CSV file used to store and
                reload the text/label/embedding rows.
        """
        self.embeddings_csv_path = embeddings_csv_path
        # NOTE(review): process-wide side effect. Presumably a workaround for
        # tools that introspect ``torch.classes.__path__`` (e.g. file
        # watchers) -- confirm it is still required.
        torch.classes.__path__ = []

    def load_data_and_model(self):
        """Load the sample dataset and the embedding model.

        Returns:
            A ``(df, model)`` tuple. ``df`` holds the ``text`` and ``label``
            columns of the first 1000 IMDB training examples; ``model`` is
            the ``all-MiniLM-L6-v2`` SentenceTransformer.
        """
        # Sample dataset (movie reviews), limited to the first 1000 examples.
        dataset = load_dataset('imdb', split='train[:1000]')
        df = pd.DataFrame(dataset)[['text', 'label']]
        # Small model (384-dimensional embeddings) chosen to fit in 4GB VRAM.
        model = SentenceTransformer('all-MiniLM-L6-v2')
        return df, model

    def generate_embeddings(self, df, model, overwrite=False):
        """Encode ``df['text']`` and cache the rows in the embeddings CSV.

        Nothing is computed when the CSV already exists, unless ``overwrite``
        is true.

        Args:
            df: DataFrame with a ``text`` column to encode.
            model: Object exposing ``encode(texts, batch_size=...,
                show_progress_bar=...)``; not used when the cache is reused.
            overwrite: Regenerate the CSV even if it already exists.

        Returns:
            The same ``df`` object that was passed in. It gains an
            ``embedding`` column (one comma-separated string per row), added
            in place, only when embeddings were generated by this call.
        """
        if os.path.exists(self.embeddings_csv_path) and not overwrite:
            return df

        texts = df['text'].tolist()
        # Encode in batches for efficiency.
        embeddings = model.encode(texts, batch_size=32, show_progress_bar=True)
        # CSV cells cannot hold arrays, so each vector is stored as a string.
        df['embedding'] = [','.join(map(str, emb)) for emb in embeddings]
        df.to_csv(self.embeddings_csv_path, index=False)
        return df

    def semantic_search(self, query, model, top_k=5):
        """Return the ``top_k`` stored texts most similar to ``query``.

        Args:
            query: Text to search for.
            model: Object exposing ``encode(list_of_texts)``; it should be the
                model that produced the stored embeddings.
            top_k: Maximum number of rows to return.

        Returns:
            DataFrame with columns ``text``, ``similarity`` and ``label``,
            ordered by decreasing similarity. It is empty when the CSV holds
            no rows.

        Raises:
            ValueError: If ``top_k`` is negative.
        """
        # DataFrame.head() with a negative count means "all but the last n
        # rows", which is never what a top-k query wants.
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")

        df = pd.read_csv(self.embeddings_csv_path)
        if df.empty:
            # np.vstack cannot stack zero arrays; there is nothing to rank.
            return pd.DataFrame(columns=['text', 'similarity', 'label'])

        # Parse the comma-separated strings back into one matrix, one row per
        # stored text.
        embeddings_matrix = np.vstack(
            [np.fromstring(cell, sep=',') for cell in df['embedding']]
        )

        query_embedding = model.encode([query])
        df['similarity'] = cosine_similarity(
            query_embedding, embeddings_matrix
        ).flatten()

        ranked = df.sort_values('similarity', ascending=False).head(top_k)
        return ranked[['text', 'similarity', 'label']]
# Execution flow: read the config, build (or reuse) the embeddings CSV, then
# run one example query and print the ranked results.
if __name__ == "__main__":
    config_path = 'config.cfg'
    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed, so an empty list means nothing was
    # loaded. Fail here instead of with a confusing KeyError further down.
    if not config.read(config_path):
        raise SystemExit(f"Could not read configuration file: {config_path}")

    embeddings_csv_path = config['SERVER']['embeddings_csv_path']
    text_search_engine_manager = TextSearchEngine(embeddings_csv_path)

    # Generate and save embeddings (skipped when the CSV already exists)
    df, model = text_search_engine_manager.load_data_and_model()
    text_search_engine_manager.generate_embeddings(df, model, overwrite=False)

    # Example search
    query = config['TEST']['query']
    results = text_search_engine_manager.semantic_search(query, model)
    print('Results -> ', results)