File size: 3,080 Bytes
11f6345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import pandas as pd
from src.inference import ClusterPredictor
import os

def main(argv=None):
    """Predict personas for crypto wallets from the command line.

    Two modes are supported:

    * ``--file PATH``: read wallet features from the CSV at PATH, predict
      every row, and write a copy of the data with ``Predicted_Cluster``
      and ``Predicted_Persona`` columns next to the input file as
      ``<name>_predicted.csv``.
    * no ``--file``: take row ``--row`` (default 0) from the bundled
      ``wallet_dataset.csv``, print its feature values, the predicted
      cluster/persona, and the per-persona scores in descending order.

    All failures are reported on stdout and the function returns without
    raising.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Predict personas for crypto wallets.")
    parser.add_argument("--file", type=str, help="Path to a CSV file containing wallet features.")
    parser.add_argument("--row", type=int, default=0, help="Row index to use from the dataset if no file is provided (uses wallet_dataset.csv).")
    args = parser.parse_args(argv)

    # Model artifacts and the sample dataset live next to this script.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(base_dir, "kmeans_model.pkl")
    preprocessor_path = os.path.join(base_dir, "wallet_power_transformer.pkl")
    default_data_path = os.path.join(base_dir, "wallet_dataset.csv")

    try:
        predictor = ClusterPredictor(model_path=model_path, preprocessor_path=preprocessor_path)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        return

    if args.file:
        input_path = args.file
        if not os.path.exists(input_path):
            print(f"Error: File {input_path} not found.")
            return
        print(f"Processing file: {input_path}")

        # pandas parser/empty-data errors and decode errors are ValueError
        # subclasses; unreadable paths (e.g. a directory) raise OSError.
        try:
            df = pd.read_csv(input_path)
        except (OSError, ValueError) as e:
            print(f"Error: Could not read {input_path}: {e}")
            return

        try:
            results = predictor.predict(df)
            if isinstance(results, list):
                df['Predicted_Cluster'] = [r['cluster_label'] for r in results]
                df['Predicted_Persona'] = [r['persona'] for r in results]
            else:
                # A single result dict: its values are broadcast to all rows.
                df['Predicted_Cluster'] = results['cluster_label']
                df['Predicted_Persona'] = results['persona']

            # Build the output name from the path stem rather than with
            # str.replace: replace() rewrites every ".csv" in the path and
            # returns the input path unchanged (so the input file would be
            # overwritten) when the extension is not a lowercase ".csv".
            stem, _ext = os.path.splitext(input_path)
            output_path = f"{stem}_predicted.csv"
            df.to_csv(output_path, index=False)
            print(f"Predictions saved to {output_path}")
        except Exception as e:
            # Top-level CLI boundary: report any prediction/write failure.
            print(f"Prediction error: {e}")

    else:
        print(f"No input file provided. Using row {args.row} from {default_data_path} as a sample.")
        if not os.path.exists(default_data_path):
            print("Default dataset not found.")
            return

        try:
            df = pd.read_csv(default_data_path)
        except (OSError, ValueError) as e:
            print(f"Error: Could not read {default_data_path}: {e}")
            return

        # iloc accepts negative positions down to -len(df); anything outside
        # [-len(df), len(df)) would raise IndexError, so reject it up front.
        if not -len(df) <= args.row < len(df):
            print(f"Row {args.row} out of bounds.")
            return

        sample = df.iloc[args.row].to_dict()
        print("\n--- Sample Features ---")
        for k, v in sample.items():
            # Only echo the columns the predictor lists as features.
            if k in predictor.FEATURES:
                print(f"{k}: {v}")

        try:
            result = predictor.predict(sample)
            print("\n--- Prediction Result ---")
            print(f"Cluster: {result['cluster_label']}")
            print(f"Persona: {result['persona']}")

            print("\n--- Confidence Scores ---")
            # Highest score first.
            sorted_probs = sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True)
            for persona, score in sorted_probs:
                print(f"{persona}: {score:.4f}")

        except Exception as e:
            # Top-level CLI boundary: report any prediction failure.
            print(f"Prediction error: {e}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()