Spaces:
Sleeping
Sleeping
| import argparse | |
| import pandas as pd | |
| from src.inference import ClusterPredictor | |
| import os | |
def _predicted_output_path(input_path):
    """Return the path the predictions CSV is written to for *input_path*.

    ``wallets.csv`` becomes ``wallets_predicted.csv``. Only the final extension
    is split off, so ``.csv`` inside a directory name is left alone. The result
    always differs from *input_path*, so the input file is never overwritten,
    even for ``.CSV`` or extension-less inputs.
    """
    root, ext = os.path.splitext(input_path)
    return f"{root}_predicted{ext}"


def _predict_file(predictor, input_path):
    """Predict every row of the CSV at *input_path* and save an annotated copy.

    Adds ``Predicted_Cluster`` and ``Predicted_Persona`` columns and writes the
    result (without the index) to ``_predicted_output_path(input_path)``.

    Returns:
        True on success. False if the file is missing or unreadable, or if
        prediction or writing fails; the reason has already been printed.
    """
    if not os.path.exists(input_path):
        print(f"Error: File {input_path} not found.")
        return False
    print(f"Processing file: {input_path}")
    try:
        df = pd.read_csv(input_path)
    except (OSError, ValueError) as e:
        # pandas' ParserError and EmptyDataError both subclass ValueError.
        print(f"Error: could not read {input_path}: {e}")
        return False
    try:
        results = predictor.predict(df)
        if isinstance(results, list):
            # One result dict per input row.
            df["Predicted_Cluster"] = [r["cluster_label"] for r in results]
            df["Predicted_Persona"] = [r["persona"] for r in results]
        else:
            # NOTE(review): presumably a single result dict (e.g. one-row
            # input); its values are assigned to the whole column. Confirm
            # against ClusterPredictor.predict.
            df["Predicted_Cluster"] = results["cluster_label"]
            df["Predicted_Persona"] = results["persona"]
        output_path = _predicted_output_path(input_path)
        df.to_csv(output_path, index=False)
    except Exception as e:  # CLI boundary: report any predictor/write failure.
        print(f"Prediction error: {e}")
        return False
    print(f"Predictions saved to {output_path}")
    return True


def _predict_sample(predictor, data_path, row):
    """Predict a single row of the CSV at *data_path* and print the details.

    Prints the row's values for the columns listed in ``predictor.FEATURES``,
    then the predicted cluster and persona, then the per-persona scores sorted
    from highest to lowest.

    Args:
        predictor: Loaded ``ClusterPredictor``.
        data_path: CSV file to take the sample row from.
        row: Row position. Negative values count from the end, as with
            ``DataFrame.iloc``.

    Returns:
        True on success, False on any failure (the reason has already been
        printed).
    """
    print(f"No input file provided. Using row {row} from {data_path} as a sample.")
    if not os.path.exists(data_path):
        print("Default dataset not found.")
        return False
    try:
        df = pd.read_csv(data_path)
    except (OSError, ValueError) as e:
        # pandas' ParserError and EmptyDataError both subclass ValueError.
        print(f"Error: could not read {data_path}: {e}")
        return False
    # Accept exactly the positions iloc accepts. A negative index below
    # -len(df) would otherwise reach iloc and raise IndexError.
    if not -len(df) <= row < len(df):
        print(f"Row {row} out of bounds.")
        return False
    sample = df.iloc[row].to_dict()

    print("\n--- Sample Features ---")
    for name, value in sample.items():
        if name in predictor.FEATURES:
            print(f"{name}: {value}")

    try:
        result = predictor.predict(sample)
        print("\n--- Prediction Result ---")
        print(f"Cluster: {result['cluster_label']}")
        print(f"Persona: {result['persona']}")
        print("\n--- Confidence Scores ---")
        ranked = sorted(
            result["probabilities"].items(), key=lambda kv: kv[1], reverse=True
        )
        for persona, score in ranked:
            print(f"{persona}: {score:.4f}")
    except Exception as e:  # CLI boundary: report any predictor failure.
        print(f"Prediction error: {e}")
        return False
    return True


def main():
    """Command-line entry point: predict wallet personas with the bundled model.

    With ``--file PATH``, every row of that CSV is predicted and an annotated
    copy is written next to it as ``<name>_predicted<ext>``. Without
    ``--file``, row ``--row`` (default 0) of the bundled ``wallet_dataset.csv``
    is predicted, and its features, cluster, persona and scores are printed.

    ``kmeans_model.pkl``, ``wallet_power_transformer.pkl`` and
    ``wallet_dataset.csv`` are resolved relative to this script's directory,
    not the current working directory.

    Returns ``None`` on success. On any failure, exits with status 1 via
    ``parser.exit`` so shells and CI can detect the error.
    """
    parser = argparse.ArgumentParser(description="Predict personas for crypto wallets.")
    parser.add_argument(
        "--file",
        type=str,
        help="Path to a CSV file containing wallet features.",
    )
    parser.add_argument(
        "--row",
        type=int,
        default=0,
        help="Row index to use from the dataset if no file is provided (uses wallet_dataset.csv).",
    )
    args = parser.parse_args()

    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(base_dir, "kmeans_model.pkl")
    preprocessor_path = os.path.join(base_dir, "wallet_power_transformer.pkl")
    default_data_path = os.path.join(base_dir, "wallet_dataset.csv")

    # NOTE(review): the .pkl artifacts are presumably unpickled by
    # ClusterPredictor; only ship/load trusted model files.
    try:
        predictor = ClusterPredictor(model_path=model_path, preprocessor_path=preprocessor_path)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        parser.exit(1)

    if args.file:
        ok = _predict_file(predictor, args.file)
    else:
        ok = _predict_sample(predictor, default_data_path, args.row)
    if not ok:
        # Non-zero exit status so callers can tell failure from success.
        parser.exit(1)
if __name__ == "__main__":
    # Script entry point: parse the command line and run one prediction pass.
    # Guarded so that importing this module has no side effects.
    main()