Spaces:
Sleeping
Sleeping
File size: 3,080 Bytes
import argparse
import pandas as pd
from src.inference import ClusterPredictor
import os
def _predicted_output_path(input_path):
    """Return the path that predictions for *input_path* are written to.

    The last ".csv" in the file name becomes "_predicted.csv", so
    "wallets.csv" -> "wallets_predicted.csv" and
    "wallets.csv.gz" -> "wallets_predicted.csv.gz". When the file name
    contains no ".csv", its extension (if any) is swapped for
    "_predicted.csv". Directory components are copied through untouched.

    The result always differs from *input_path*, so saving the predictions
    can never overwrite the input file.
    """
    filename = os.path.basename(input_path)
    # Slice instead of os.path.join so the caller's separators are preserved.
    directory = input_path[: len(input_path) - len(filename)]
    stem, marker, tail = filename.rpartition(".csv")
    if not marker:
        # No ".csv" in the name (e.g. "data.CSV", "data.txt", "data").
        stem, tail = os.path.splitext(filename)[0], ""
    return f"{directory}{stem}_predicted.csv{tail}"


def _predict_file(predictor, input_path):
    """Predict every row of the CSV at *input_path* and save the results.

    Adds ``Predicted_Cluster`` and ``Predicted_Persona`` columns to the loaded
    frame and writes it next to the input (see ``_predicted_output_path``).
    Problems are reported on stdout; nothing is raised to the caller.
    """
    if not os.path.exists(input_path):
        print(f"Error: File {input_path} not found.")
        return

    print(f"Processing file: {input_path}")
    try:
        df = pd.read_csv(input_path)
    except (OSError, ValueError) as e:
        # pandas' EmptyDataError/ParserError and UnicodeDecodeError are
        # ValueErrors; directories and permission problems are OSErrors.
        print(f"Error: could not read {input_path}: {e}")
        return

    try:
        results = predictor.predict(df)
        if isinstance(results, list):
            # One result mapping per input row.
            df['Predicted_Cluster'] = [r['cluster_label'] for r in results]
            df['Predicted_Persona'] = [r['persona'] for r in results]
        else:
            # A single result mapping: its values are assigned to the
            # whole column.
            df['Predicted_Cluster'] = results['cluster_label']
            df['Predicted_Persona'] = results['persona']
        output_path = _predicted_output_path(input_path)
        df.to_csv(output_path, index=False)
        print(f"Predictions saved to {output_path}")
    except Exception as e:
        # Top-level CLI boundary: report the failure instead of a traceback.
        print(f"Prediction error: {e}")


def _predict_sample_row(predictor, dataset_path, row):
    """Predict one row of the CSV at *dataset_path* and print the details.

    *row* is a positional index; negative values count from the end of the
    dataset, as with ``DataFrame.iloc``. Prints the row's values for the
    columns listed in ``predictor.FEATURES``, then the predicted cluster and
    persona, then the entries of the result's ``probabilities`` mapping in
    descending order of score. Problems are reported on stdout.
    """
    print(f"No input file provided. Using row {row} from {dataset_path} as a sample.")
    if not os.path.exists(dataset_path):
        print("Default dataset not found.")
        return

    try:
        df = pd.read_csv(dataset_path)
    except (OSError, ValueError) as e:
        print(f"Error: could not read {dataset_path}: {e}")
        return

    # Check both ends: iloc raises IndexError for rows below -len(df) too.
    if not -len(df) <= row < len(df):
        print(f"Row {row} out of bounds.")
        return

    sample = df.iloc[row].to_dict()
    print("\n--- Sample Features ---")
    for k, v in sample.items():
        if k in predictor.FEATURES:
            print(f"{k}: {v}")

    try:
        result = predictor.predict(sample)
        print("\n--- Prediction Result ---")
        print(f"Cluster: {result['cluster_label']}")
        print(f"Persona: {result['persona']}")
        print("\n--- Confidence Scores ---")
        sorted_probs = sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True)
        for persona, score in sorted_probs:
            print(f"{persona}: {score:.4f}")
    except Exception as e:
        # Top-level CLI boundary: report the failure instead of a traceback.
        print(f"Prediction error: {e}")


def main():
    """Command-line entry point: predict personas for crypto wallets.

    With ``--file PATH`` every row of that CSV is predicted and the frame is
    saved with two extra columns. Without it, row ``--row`` (default 0) of
    ``wallet_dataset.csv`` is predicted and printed as a sample.

    The model, preprocessor and default dataset are looked up in the
    directory containing this script. All failures are reported on stdout and
    the function returns ``None`` in every case.
    """
    parser = argparse.ArgumentParser(description="Predict personas for crypto wallets.")
    parser.add_argument("--file", type=str, help="Path to a CSV file containing wallet features.")
    parser.add_argument("--row", type=int, default=0, help="Row index to use from the dataset if no file is provided (uses wallet_dataset.csv).")
    args = parser.parse_args()

    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(base_dir, "kmeans_model.pkl")
    preprocessor_path = os.path.join(base_dir, "wallet_power_transformer.pkl")
    default_data_path = os.path.join(base_dir, "wallet_dataset.csv")

    try:
        predictor = ClusterPredictor(model_path=model_path, preprocessor_path=preprocessor_path)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        return

    if args.file:
        _predict_file(predictor, args.file)
    else:
        _predict_sample_row(predictor, default_data_path, args.row)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|