Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import faiss | |
| import pickle | |
| import os | |
| import argparse | |
| from sentence_transformers import SentenceTransformer | |
def main():
    """Embed labeled reviews and persist them as a FAISS vector store.

    Command-line flags (parsed from ``sys.argv`` with argparse):
        --input_csv: CSV file holding the review texts.
        --out_dir: directory that receives the outputs (created if missing).
        --text_column: column of the CSV containing the review text.
        --model_name: sentence-transformers model used for encoding.

    Files written to ``out_dir``:
        reviews.index: FAISS ``IndexFlatIP`` over L2-normalized embeddings.
        reviews_metadata.pkl: pickled DataFrame of exactly the rows that
            were embedded, re-indexed 0..n-1 in the order they were added
            to the index.

    Input problems (missing file, missing column, no usable text) are
    reported on stdout and the function returns ``None`` before the
    embedding model is loaded or anything is written.
    """
    parser = argparse.ArgumentParser(
        description="Embed labeled reviews into FAISS vectorstore."
    )
    parser.add_argument(
        '--input_csv',
        type=str,
        default='data/processed/amazon_labeled_reviews.csv',
        help='CSV with texts and predicted_aspects',
    )
    parser.add_argument(
        '--out_dir',
        type=str,
        default='vectorstore',
        help='Directory to output FAISS index',
    )
    parser.add_argument(
        '--text_column',
        type=str,
        default='reviewDocument',
        help='Name of the column containing the review text',
    )
    parser.add_argument(
        '--model_name',
        type=str,
        default='all-MiniLM-L6-v2',
        help='Sentence-transformers model',
    )
    args = parser.parse_args()

    if not os.path.exists(args.input_csv):
        print(f"Error: {args.input_csv} not found. Please run src/prepare_amazon_data.py first.")
        return

    print(f"Loading data from {args.input_csv}...")
    df = pd.read_csv(args.input_csv)

    if args.text_column not in df.columns:
        print(f"Error: Column {args.text_column} not found in {args.input_csv}.")
        print(f"Available columns: {df.columns.tolist()}")
        return

    # Drop rows whose text is missing, empty, or whitespace-only. The cast to
    # str keeps the check working when pandas inferred a non-string dtype.
    text_series = df[args.text_column]
    has_text = text_series.notna() & (text_series.astype(str).str.strip() != '')

    # Re-index so the label of each saved row equals its position, which is
    # the order the vectors are added to the index below. Without this, any
    # dropped row leaves the metadata labels out of step with the vector ids.
    df = df[has_text].reset_index(drop=True)
    texts = df[args.text_column].astype(str).tolist()

    if not texts:
        # Nothing to embed: stop before loading the model, because an empty
        # encode() result has no second axis to read the dimension from.
        print(f"Error: No non-empty texts found in column {args.text_column} of {args.input_csv}.")
        return

    print(f"Loading embedding model: {args.model_name}...")
    embedder = SentenceTransformer(args.model_name)

    print(f"Encoding {len(texts)} reviews into dense vectors...")
    # encode handles batching automatically
    embeddings = embedder.encode(texts, show_progress_bar=True, convert_to_numpy=True)

    # Flat inner-product index; on L2-normalized vectors the inner product
    # equals cosine similarity.
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)

    # Normalize (in place) before adding to FAISS for cosine similarity
    print("Normalizing vectors and adding to FAISS index...")
    faiss.normalize_L2(embeddings)
    index.add(embeddings)

    os.makedirs(args.out_dir, exist_ok=True)

    # Save FAISS index
    index_path = os.path.join(args.out_dir, 'reviews.index')
    faiss.write_index(index, index_path)

    # Save the dataframe mapping (so we can retrieve the actual text + labels later)
    df_path = os.path.join(args.out_dir, 'reviews_metadata.pkl')
    with open(df_path, 'wb') as f:
        pickle.dump(df, f)

    print(f"\nSuccessfully embedded {len(texts)} reviews and saved to '{args.out_dir}'!")
# Script entry point: run the embedding pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()