# Source listing header: file size 4,441 Bytes, 24c3333
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from pymongo import MongoClient
import requests
import random
from sklearn.preprocessing import StandardScaler
import os
from dotenv import load_dotenv
# Load environment variables from .env file.
# This runs at import time, before get_mongo_connection() reads MONGODB_URI
# through os.getenv.
load_dotenv()
# MongoDB Connection Setup
def get_mongo_connection():
    """Return the ``wallets`` collection of the ``walletAnalyzer`` database.

    The connection string is read from the ``MONGODB_URI`` environment
    variable.

    Returns:
        The collection handle ``client["walletAnalyzer"]["wallets"]``.

    Raises:
        RuntimeError: If ``MONGODB_URI`` is unset or empty.
    """
    mongo_uri = os.getenv("MONGODB_URI")
    if not mongo_uri:
        # Passing None to MongoClient does not fail; it falls back to the
        # client's default host, so a missing setting would otherwise only
        # surface later as an opaque connection error.
        raise RuntimeError(
            "MONGODB_URI is not set; define it in the environment or in .env"
        )

    db_name = "walletAnalyzer"
    collection_name = "wallets"

    # NOTE(review): tlsAllowInvalidCertificates=True disables server
    # certificate validation. It is kept to preserve existing behaviour, but
    # it should be removed once the deployment presents a trusted certificate.
    client = MongoClient(
        mongo_uri,
        tls=True,
        tlsAllowInvalidCertificates=True,
    )
    return client[db_name][collection_name]
# Module-level handle to the wallets collection. It is created once at import
# time and used by test_model_on_random_wallet to sample wallets.
collection = get_mongo_connection()
# Fetch Aggregated Data from API
def fetch_aggregated_data(api_url, timeout=10):
    """Fetch the aggregated statistics document from ``api_url``.

    Args:
        api_url: URL that is requested with HTTP GET.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled server.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.exceptions.RequestException: On connection failures,
            timeouts, or a 4xx/5xx response status.
    """
    response = requests.get(api_url, timeout=timeout)
    # Fail here on an error status instead of trying to parse an error page
    # as if it were the expected payload.
    response.raise_for_status()
    return response.json()
# Extract States from Wallet Data
def _to_float(value, default=0.0):
    """Coerce a stored numeric value to ``float``.

    Handles plain numbers, numeric strings, ``None`` (mapped to ``default``)
    and extended-JSON wrappers such as ``{"$numberDouble": "1.5"}``.
    """
    if value is None:
        return default
    if isinstance(value, dict):
        for key in ("$numberDouble", "$numberDecimal", "$numberLong", "$numberInt"):
            if key in value:
                return float(value[key])
    # Unknown dict shapes fall through and raise TypeError, as before.
    return float(value)


def extract_state(wallet, api_data):
    """Build the feature vector for one wallet.

    The vector is laid out as::

        [totalValue, number of holdings, average market cap,
         averagePortfolioValue, totalPortfolioValue,
         *balances, *prices]

    so its length is ``5 + 2 * len(wallet["topHoldings"])``.

    Args:
        wallet: Wallet document with ``topHoldings`` (a list of dicts that
            may carry ``marketCap``, ``price`` and ``balance``) and
            ``totalValue``.
        api_data: Aggregated statistics containing
            ``portfolioMetrics.averagePortfolioValue`` and
            ``portfolioMetrics.totalPortfolioValue``.

    Returns:
        A 1-D float ``numpy`` array, or ``None`` when the wallet has no
        ``topHoldings``.

    Raises:
        KeyError: If ``totalValue`` or the ``portfolioMetrics`` entries are
            missing.
    """
    holdings = wallet.get("topHoldings")
    if not holdings:
        return None  # Skip wallets without topHoldings

    # Every numeric field goes through the same coercion, so a missing or
    # None field counts as 0 and extended-JSON wrappers are unwrapped
    # everywhere, not only for totalValue.
    market_caps = [_to_float(h.get("marketCap")) for h in holdings]
    token_prices = [_to_float(h.get("price")) for h in holdings]
    token_balances = [_to_float(h.get("balance")) for h in holdings]

    # holdings is non-empty here, so the division is safe.
    avg_market_cap = sum(market_caps) / len(market_caps)

    total_value = _to_float(wallet["totalValue"])
    metrics = api_data["portfolioMetrics"]

    state = [
        total_value,                        # Wallet's total portfolio value
        len(holdings),                      # Number of holdings
        avg_market_cap,                     # Average market cap of top holdings
        metrics["averagePortfolioValue"],   # Average portfolio value (API-wide)
        metrics["totalPortfolioValue"],     # Total portfolio value (API-wide)
    ] + token_balances + token_prices

    # Force a float dtype so a stray string cannot turn the whole state into
    # a string array.
    return np.array(state, dtype=float)
# Normalize Features
def normalize_states(states, scaler=None):
    """Scale ``states`` and return them together with the scaler used.

    Args:
        states: Collection of state vectors to scale.
        scaler: Optional already-fitted scaler. When given, it is applied
            with ``transform``; when omitted, a new ``StandardScaler`` is
            fitted on ``states`` itself.

    Returns:
        A ``(scaled_states, scaler)`` tuple.
    """
    if scaler is not None:
        # Reuse the supplied scaler unchanged.
        return scaler.transform(states), scaler

    fitted = StandardScaler()
    return fitted.fit_transform(states), fitted
# Load and Test the Trained Model
def test_model_on_random_wallet(api_url, model_path, scaler=None, max_retries=10):
    """Load the trained model and print its preferred action per token.

    A random wallet is sampled from the module-level ``collection``, turned
    into a state vector, normalized, and fed to the model. The model output
    is read as consecutive groups of three Q-values (Buy, Sell, Hold), one
    group per entry of the wallet's ``topHoldings``.

    Args:
        api_url: URL passed to ``fetch_aggregated_data``.
        model_path: Path of the saved Keras model.
        scaler: Optional pre-fitted scaler (normally the one used during
            training). When omitted, a new scaler is fitted on the single
            sampled state, which maps every feature to zero, so predictions
            made that way are not meaningful.
        max_retries: Number of random wallets to try before giving up on
            finding one with ``topHoldings``.

    Returns:
        None. Results are printed.
    """
    # Load the trained model without compilation
    model = tf.keras.models.load_model(model_path, compile=False)
    print("Model loaded successfully.")

    # Recompile the model with the correct loss function
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')
    print("Model recompiled with MSE loss.")

    # Fetch API data
    api_data = fetch_aggregated_data(api_url)

    # Retry fetching a wallet with topHoldings
    for attempt in range(max_retries):
        try:
            random_wallet = collection.aggregate([{"$sample": {"size": 1}}]).next()
        except StopIteration:
            # $sample on an empty collection yields nothing; report it
            # instead of leaking a bare StopIteration to the caller.
            print("No wallets found in the collection.")
            return
        state = extract_state(random_wallet, api_data)
        if state is not None:
            break
        print(f"Attempt {attempt + 1}: Wallet has no topHoldings. Retrying...")
    else:
        print("Failed to fetch a valid wallet with topHoldings after multiple attempts.")
        return

    # Normalize the state (use a pre-fitted scaler if available)
    if scaler is None:
        print(
            "Warning: no pre-fitted scaler supplied; fitting on a single state "
            "normalizes every feature to 0."
        )
    states, _ = normalize_states([state], scaler)
    normalized_state = np.asarray(states[0])  # The first (and only) normalized state

    # Predict actions using the model
    q_values = model.predict(normalized_state.reshape(1, -1))[0]

    # Iterate over all tokens in the wallet and output the best action
    print(f"Testing on wallet: {random_wallet['address']}")
    print(f"Q-values: {q_values}")

    actions = ["Buy", "Sell", "Hold"]
    n_actions = len(actions)
    for token_index, holding in enumerate(random_wallet["topHoldings"]):
        start = token_index * n_actions
        token_q_values = q_values[start:start + n_actions]
        if len(token_q_values) < n_actions:
            # The model emits fewer Q-values than this wallet has tokens;
            # np.argmax on an empty slice would raise, so stop here.
            print(
                f"Model output covers only {token_index} token(s); "
                "remaining holdings are skipped."
            )
            break

        # Determine the best action for this token
        best_action = actions[int(np.argmax(token_q_values))]

        # Output the action for the token
        token_symbol = holding["symbol"]
        print(f"Token: {token_symbol}, Best action: {best_action}, Q-values: {token_q_values}")
if __name__ == "__main__":
    # Endpoint that serves the aggregated portfolio metrics.
    API_URL = "https://soltrendio.com/api/stats/getTrends"  # Replace with your actual API URL
    MODEL_PATH = "rl_wallet_model.h5"  # Path to the trained model

    # Test the model on a random wallet
    test_model_on_random_wallet(API_URL, MODEL_PATH)