File size: 4,441 Bytes
24c3333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from pymongo import MongoClient
import requests
import random
from sklearn.preprocessing import StandardScaler
import os
from dotenv import load_dotenv

# Populate environment variables from a .env file at import time, so that
# get_mongo_connection() can read MONGODB_URI via os.getenv().
load_dotenv()

# MongoDB Connection Setup
def get_mongo_connection(uri=None, db_name="walletAnalyzer",
                         collection_name="wallets", *,
                         allow_invalid_certificates=True):
    """Return the MongoDB collection that stores analyzed wallets.

    Args:
        uri: MongoDB connection string. When None, the ``MONGODB_URI``
            environment variable is used.
        db_name: Name of the database holding the collection.
        collection_name: Name of the wallet collection.
        allow_invalid_certificates: Forwarded to pymongo's
            ``tlsAllowInvalidCertificates``. Defaults to True for backward
            compatibility, but this DISABLES TLS certificate verification and
            leaves the connection open to man-in-the-middle attacks; pass
            False wherever the server presents a valid certificate.

    Returns:
        A pymongo ``Collection`` handle.
    """
    if uri is None:
        uri = os.getenv("MONGODB_URI")
        if uri is None:
            import warnings

            # MongoClient(None) falls back to localhost, so a missing variable
            # would otherwise silently target the wrong server. Warn instead of
            # raising because this function is called at module import.
            warnings.warn(
                "MONGODB_URI is not set; MongoClient will fall back to localhost.",
                RuntimeWarning,
                stacklevel=2,
            )

    client = MongoClient(
        uri,
        tls=True,
        tlsAllowInvalidCertificates=allow_invalid_certificates,
    )
    return client[db_name][collection_name]

# Module-level collection handle, created once at import time by
# get_mongo_connection(); test_model_on_random_wallet samples wallets from it.
collection = get_mongo_connection()

# Fetch Aggregated Data from API
def fetch_aggregated_data(api_url, timeout=30):
    """Fetch the aggregated stats payload from ``api_url`` and decode it as JSON.

    Args:
        api_url: URL of the stats endpoint.
        timeout: Seconds to wait for the server (connect and read) before
            giving up. Without a timeout, ``requests`` can block indefinitely
            on an unresponsive server.

    Returns:
        The decoded JSON body.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status, so an
            error page is never mistaken for valid data.
        requests.Timeout: if the server does not respond within ``timeout``.
    """
    response = requests.get(api_url, timeout=timeout)
    response.raise_for_status()
    return response.json()

# Extract States from Wallet Data
# Extended-JSON wrapper keys that can carry a number instead of a plain value.
_EXTENDED_JSON_NUMBER_KEYS = ("$numberDouble", "$numberDecimal", "$numberLong", "$numberInt")


def _to_float(value, default=0.0):
    """Convert a numeric field to float.

    Accepts plain numbers, numeric strings, and extended-JSON wrappers such as
    ``{"$numberDouble": "1.5"}``. ``None`` maps to ``default``.

    Raises:
        TypeError: for a dict that is not a recognised numeric wrapper.
    """
    if value is None:
        return default
    if isinstance(value, dict):
        for key in _EXTENDED_JSON_NUMBER_KEYS:
            if key in value:
                return float(value[key])
        raise TypeError(f"Unsupported numeric value: {value!r}")
    return float(value)


def extract_state(wallet, api_data):
    """Build the feature vector describing one wallet.

    Layout: ``[totalValue, number of holdings, mean marketCap,
    averagePortfolioValue, totalPortfolioValue, *balances, *prices]``.
    The length is therefore ``5 + 2 * len(topHoldings)`` and varies per wallet.

    Args:
        wallet: Wallet document with ``totalValue`` and ``topHoldings``.
        api_data: Decoded stats payload containing ``portfolioMetrics``.

    Returns:
        A float ``np.ndarray``, or None when the wallet has no ``topHoldings``.
    """
    holdings = wallet.get("topHoldings")
    if not holdings:
        return None  # Skip wallets without topHoldings

    # Missing or null fields count as 0; wrapped numbers are unwrapped the same
    # way for every field, not only for totalValue.
    market_caps = [_to_float(h.get("marketCap")) for h in holdings]
    token_prices = [_to_float(h.get("price")) for h in holdings]
    token_balances = [_to_float(h.get("balance")) for h in holdings]

    # holdings is non-empty here, so the division is safe.
    avg_market_cap = sum(market_caps) / len(market_caps)

    metrics = api_data["portfolioMetrics"]
    state = [
        _to_float(wallet["totalValue"]),  # Total portfolio value
        float(len(holdings)),  # Number of holdings
        avg_market_cap,  # Average market cap of top holdings
        _to_float(metrics["averagePortfolioValue"]),  # Average portfolio value
        _to_float(metrics["totalPortfolioValue"]),  # Total portfolio value
    ] + token_balances + token_prices

    # Explicit dtype: a stray string/None must never turn this into an object array.
    return np.array(state, dtype=float)

# Normalize Features
def normalize_states(states, scaler=None):
    """Standardize a batch of state vectors.

    With no ``scaler``, a new StandardScaler is fitted on ``states`` and
    returned together with the transformed data. With a scaler, that
    (already fitted) scaler is only applied, never refitted.

    Returns:
        A ``(transformed_states, scaler_used)`` tuple.
    """
    if scaler is not None:
        return scaler.transform(states), scaler

    fitted_scaler = StandardScaler()
    transformed = fitted_scaler.fit_transform(states)
    return transformed, fitted_scaler

# Load and Test the Trained Model
# Action labels, in the order the model's Q-values are laid out for each token.
_ACTIONS = ("Buy", "Sell", "Hold")


def _sample_wallet_state(api_data, max_retries):
    """Sample random wallets until one yields a state; return (wallet, state) or (None, None)."""
    for attempt in range(max_retries):
        wallet = next(collection.aggregate([{"$sample": {"size": 1}}]), None)
        if wallet is None:
            # $sample returned nothing, so the collection is empty and retrying cannot help.
            print("The wallet collection is empty; nothing to sample.")
            return None, None
        state = extract_state(wallet, api_data)
        if state is not None:
            return wallet, state
        print(f"Attempt {attempt + 1}: Wallet has no topHoldings. Retrying...")
    print("Failed to fetch a valid wallet with topHoldings after multiple attempts.")
    return None, None


def _expected_input_width(model):
    """Return the model's expected feature count, or None if it cannot be determined."""
    shape = getattr(model, "input_shape", None)
    # Single-input models report a tuple like (None, n_features); anything else is left unchecked.
    if isinstance(shape, tuple) and shape and isinstance(shape[-1], int):
        return shape[-1]
    return None


def _print_token_actions(holdings, q_values):
    """Print the arg-max action for each holding, reading len(_ACTIONS) Q-values per token."""
    width = len(_ACTIONS)
    for token_index, holding in enumerate(holdings):
        start = token_index * width
        token_q_values = q_values[start:start + width]
        if len(token_q_values) < width:
            # The model output is shorter than width * len(holdings): the remaining
            # holdings have no Q-values (argmax on an empty slice would raise).
            print(f"Model output covers only {token_index} of {len(holdings)} holdings; stopping.")
            break
        best_action = _ACTIONS[int(np.argmax(token_q_values))]
        token_symbol = holding.get("symbol", "<unknown>")
        print(f"Token: {token_symbol}, Best action: {best_action}, Q-values: {token_q_values}")


def test_model_on_random_wallet(api_url, model_path, scaler=None, max_retries=10):
    """Load a trained Q-network and print its suggested action for each
    holding of one randomly sampled wallet.

    Args:
        api_url: Stats endpoint whose ``portfolioMetrics`` feed into the state.
        model_path: Path of the saved Keras model.
        scaler: Optional pre-fitted scaler (e.g. the one fitted on the training
            states). Fitting a StandardScaler on this single state instead would
            map every feature to 0 and give the model the same all-zero input
            for every wallet. Without a scaler, the raw features are used and a
            warning is printed.
        max_retries: How many random wallets to try before giving up on
            finding one with ``topHoldings``.

    Returns:
        None. Results are printed. The function returns early, after printing
        the reason, when no usable wallet is found or the state width does not
        match the model input.
    """
    # Load the trained model without compilation
    model = tf.keras.models.load_model(model_path, compile=False)
    print("Model loaded successfully.")

    # Not needed for predict(); kept so the model is ready for further
    # training/evaluation with the same settings as before.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')
    print("Model recompiled with MSE loss.")

    api_data = fetch_aggregated_data(api_url)

    random_wallet, state = _sample_wallet_state(api_data, max_retries)
    if random_wallet is None:
        return

    address = random_wallet.get("address", "<unknown>")
    expected_width = _expected_input_width(model)
    if expected_width is not None and state.shape[0] != expected_width:
        # extract_state's output grows with the number of holdings, so only
        # wallets of matching size fit a fixed-width model input.
        print(f"Wallet {address} yields {state.shape[0]} features but the model "
              f"expects {expected_width}; cannot score it.")
        return

    if scaler is not None:
        normalized, _ = normalize_states([state], scaler)
        features = np.asarray(normalized[0])
    else:
        print("Warning: no pre-fitted scaler supplied; using unnormalized features.")
        features = state

    q_values = model.predict(features.reshape(1, -1))[0]

    print(f"Testing on wallet: {address}")
    print(f"Q-values: {q_values}")
    _print_token_actions(random_wallet["topHoldings"], q_values)

if __name__ == "__main__":
    # Both settings may be overridden from the environment (including the .env
    # file loaded at import); the defaults are the values this script shipped with.
    API_URL = os.getenv("WALLET_API_URL", "https://soltrendio.com/api/stats/getTrends")
    MODEL_PATH = os.getenv("WALLET_MODEL_PATH", "rl_wallet_model.h5")  # Path to the trained model

    # Test the model on a random wallet
    test_model_on_random_wallet(API_URL, MODEL_PATH)