Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import os | |
| import argparse | |
| from tqdm import tqdm | |
| from inference import ABSAPredictor | |
| import json | |
def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if it has one.

    ``os.path.dirname`` returns ``""`` for a bare filename such as
    ``reviews.csv``, and ``os.makedirs("")`` raises ``FileNotFoundError``,
    so directory creation is skipped in that case.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _write_dummy_reviews(path, text_column):
    """Write a tiny CSV of hand-written laptop reviews to ``path``.

    The single column is named ``text_column`` so the rest of the pipeline
    can run end-to-end without a real dataset. Returns the number of rows
    written.
    """
    _ensure_parent_dir(path)
    dummy_df = pd.DataFrame({
        text_column: [
            "The laptop is great but the battery life is terrible.",
            "Amazing screen, very bright. Shipping was fast.",
            "Customer service was completely unhelpful when my screen broke.",
            "Way too expensive for what you get. The keyboard feels cheap.",
            "I love this computer, it runs all my games smoothly.",
        ]
    })
    dummy_df.to_csv(path, index=False)
    return len(dummy_df)


def _label_text(predictor, text):
    """Return ``predictor.predict(text)`` serialized with ``json.dumps``.

    The value is stringified and stripped first; blank reviews are not sent
    to the model and get the literal ``"{}"``. The predictor output is
    presumably a dict of aspect -> sentiment (TODO confirm against
    ``inference.ABSAPredictor``); it must be JSON-serializable.
    """
    text = str(text).strip()
    if not text:
        return "{}"
    return json.dumps(predictor.predict(text))


def main(argv=None):
    """Auto-label reviews in a CSV with the fine-tuned ABSA model.

    Reads ``--input_csv`` and, if it has more than ``--sample_size`` rows,
    down-samples it reproducibly (``random_state=42``). It then runs
    ``ABSAPredictor.predict`` on every review in ``--text_column``. Finally it
    writes all columns plus a ``predicted_aspects`` JSON-string column to
    ``--output_csv``. If the input file does not exist, a 5-row dummy CSV is
    first created at that path so the pipeline can still be exercised.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps normal command-line behaviour.

    Raises:
        SystemExit: status 2 (via ``parser.error``) when ``--sample_size`` is
            not a positive integer. Status 1 (message on stderr) when
            ``--text_column`` is not a column of the input CSV.
    """
    parser = argparse.ArgumentParser(
        description="Auto-label an Amazon dataset using the fine-tuned ABSA model.")
    parser.add_argument('--input_csv', type=str, default='data/raw/amazon_reviews.csv',
                        help='Raw Amazon reviews CSV')
    parser.add_argument('--output_csv', type=str,
                        default='data/processed/amazon_labeled_reviews.csv',
                        help='Output CSV with labeled aspects')
    parser.add_argument('--text_column', type=str, default='reviewDocument',
                        help='Name of the column containing the review text')
    parser.add_argument('--sample_size', type=int, default=500,
                        help='Number of reviews to process (for speed/demo)')
    args = parser.parse_args(argv)

    # df.sample(n=<negative>) raises, and n=0 would write an empty file.
    if args.sample_size < 1:
        parser.error("--sample_size must be a positive integer")

    if not os.path.exists(args.input_csv):
        print(f"Error: {args.input_csv} not found.")
        print("Please download a sample Amazon Electronics/Laptop reviews CSV "
              f"and place it as '{args.input_csv}'")
        print("The CSV should have a column containing the review text "
              "(default name: 'reviewDocument').")
        # Create a tiny dummy file so the rest of the project can be built and tested.
        print(f"\nCreating a small dummy CSV at {args.input_csv} for testing purposes...")
        n_dummy = _write_dummy_reviews(args.input_csv, args.text_column)
        print(f"Created {args.input_csv} with {n_dummy} sample reviews.")

    print(f"\nLoading reviews from {args.input_csv}...")
    df = pd.read_csv(args.input_csv)
    if args.text_column not in df.columns:
        # Non-zero exit so shell pipelines notice the failure.
        parser.exit(1, f"Error: Column '{args.text_column}' not found in {args.input_csv}. "
                       f"Available columns: {df.columns.tolist()}\n")

    # Take a sample if the dataset is huge, since local inference on CPU takes time.
    if len(df) > args.sample_size:
        print(f"Sampling {args.sample_size} reviews from a total of {len(df)}...")
        df = df.sample(n=args.sample_size, random_state=42).copy()

    predictor = ABSAPredictor()
    print("Auto-labeling reviews...")
    texts = df[args.text_column].fillna('')
    # Store each prediction as a JSON string so it survives the CSV round trip.
    df['predicted_aspects'] = [_label_text(predictor, text) for text in tqdm(texts)]

    _ensure_parent_dir(args.output_csv)
    df.to_csv(args.output_csv, index=False)
    print(f"\nSuccessfully labeled {len(df)} reviews!")
    print(f"Saved to: {args.output_csv}")
if __name__ == '__main__':
    # Run the labeling CLI only when this file is executed directly, not on import.
    main()