|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from tensorflow.keras import layers |
|
|
from pymongo import MongoClient |
|
|
import requests |
|
|
import random |
|
|
from sklearn.preprocessing import StandardScaler |
|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
# Populate os.environ from a .env file via python-dotenv so that
# MONGODB_URI (read in get_mongo_connection) is available to this script.
load_dotenv()
|
|
|
|
|
|
|
|
def get_mongo_connection(
    db_name="walletAnalyzer",
    collection_name="wallets",
    *,
    allow_invalid_certificates=None,
):
    """Return the MongoDB collection that holds the analysed wallets.

    The connection string is read from the ``MONGODB_URI`` environment
    variable.

    Args:
        db_name: Database to open. Defaults to ``"walletAnalyzer"``.
        collection_name: Collection inside ``db_name``. Defaults to
            ``"wallets"``.
        allow_invalid_certificates: Skip TLS certificate validation. ``None``
            (the default) consults ``MONGODB_TLS_ALLOW_INVALID_CERTS`` in the
            environment ("1", "true" or "yes" enable it); otherwise the
            server certificate is validated.

    Returns:
        A pymongo collection object. ``MongoClient`` does not block while
        connecting, so network problems surface on first use, not here.

    Raises:
        RuntimeError: If ``MONGODB_URI`` is unset or empty.
    """
    mongo_uri = os.getenv("MONGODB_URI")
    if not mongo_uri:
        # MongoClient(None) silently falls back to localhost, which would
        # hide a missing .env entry behind a confusing connection timeout.
        raise RuntimeError(
            "MONGODB_URI is not set; add it to the environment or the .env file."
        )

    if allow_invalid_certificates is None:
        flag = os.getenv("MONGODB_TLS_ALLOW_INVALID_CERTS", "")
        allow_invalid_certificates = flag.strip().lower() in {"1", "true", "yes"}

    client = MongoClient(
        mongo_uri,
        tls=True,
        # Validation stays on unless explicitly disabled: skipping it lets
        # anyone on the network path impersonate the database server.
        tlsAllowInvalidCertificates=allow_invalid_certificates,
    )
    return client[db_name][collection_name]
|
|
|
|
|
# Module-level handle to the wallets collection, used by
# test_model_on_random_wallet() to sample wallets.
# NOTE(review): this runs at import time, so merely importing this module
# constructs a MongoClient.
collection = get_mongo_connection()
|
|
|
|
|
|
|
|
def fetch_aggregated_data(api_url, timeout=30):
    """Fetch the aggregated trends payload from ``api_url`` and decode it as JSON.

    Args:
        api_url: Endpoint returning a JSON document. ``extract_state`` later
            reads ``["portfolioMetrics"]["averagePortfolioValue"]`` and
            ``["portfolioMetrics"]["totalPortfolioValue"]`` from it.
        timeout: Seconds to wait for the connection and for each read.
            Without a timeout, ``requests`` can block indefinitely on a
            stalled server.

    Returns:
        The decoded JSON body.

    Raises:
        requests.HTTPError: For 4xx/5xx responses. Raising here avoids
            returning an error body that would fail later with an opaque
            ``KeyError``.
        requests.RequestException: On connection errors or timeouts.
    """
    response = requests.get(api_url, timeout=timeout)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
|
|
_EXTENDED_JSON_NUMBER_KEYS = ("$numberDouble", "$numberDecimal", "$numberLong", "$numberInt")


def _to_float(value, default=0.0):
    """Coerce a numeric field from a wallet document or API payload to ``float``.

    Accepts plain numbers, numeric strings, ``None`` (mapped to ``default``),
    and MongoDB Extended JSON wrappers such as ``{"$numberDouble": "1.5"}``.

    Raises:
        TypeError: For a dict that is not a recognised Extended JSON number,
            or any other non-numeric type.
        ValueError: For a string that is not a valid number.
    """
    if value is None:
        return default
    if isinstance(value, dict):
        for key in _EXTENDED_JSON_NUMBER_KEYS:
            if key in value:
                return float(value[key])
        raise TypeError(f"Unsupported numeric document: {value!r}")
    return float(value)


def extract_state(wallet, api_data):
    """Build the numeric state vector for one wallet.

    Layout (length ``5 + 2 * k`` for ``k`` top holdings)::

        [totalValue, k, mean holding marketCap,
         averagePortfolioValue, totalPortfolioValue,
         balance_0 .. balance_{k-1}, price_0 .. price_{k-1}]

    Missing or ``null`` holding fields count as 0.

    NOTE(review): the vector length depends on ``k``, so it only fits a
    fixed-input model when ``k`` matches what the model was trained on.

    Args:
        wallet: Wallet document with ``topHoldings`` (a list of dicts with
            optional ``marketCap``/``price``/``balance``) and ``totalValue``.
        api_data: Payload from ``fetch_aggregated_data``. Only
            ``api_data["portfolioMetrics"]`` is read.

    Returns:
        A 1-D ``float64`` ``np.ndarray``, or ``None`` when the wallet has no
        (or an empty) ``topHoldings`` list.

    Raises:
        KeyError: If ``totalValue`` or a portfolio metric is missing.
        TypeError, ValueError: If a field holds a non-numeric value.
    """
    holdings = wallet.get("topHoldings")
    if not holdings:
        return None

    # h.get(...) may return None for keys present with a null value;
    # _to_float maps that to 0.0 instead of raising TypeError.
    market_caps = [_to_float(h.get("marketCap")) for h in holdings]
    token_prices = [_to_float(h.get("price")) for h in holdings]
    token_balances = [_to_float(h.get("balance")) for h in holdings]

    # holdings is non-empty here, so the mean is well defined.
    avg_market_cap = sum(market_caps) / len(market_caps)

    metrics = api_data["portfolioMetrics"]
    state = [
        _to_float(wallet["totalValue"]),
        float(len(holdings)),
        avg_market_cap,
        _to_float(metrics["averagePortfolioValue"]),
        _to_float(metrics["totalPortfolioValue"]),
        *token_balances,
        *token_prices,
    ]
    return np.array(state, dtype=float)
|
|
|
|
|
|
|
|
def normalize_states(states, scaler=None):
    """Standardise a batch of state vectors.

    Args:
        states: 2-D array-like of shape ``(n_samples, n_features)``, as
            scikit-learn scalers expect.
        scaler: An already-fitted scaler to reuse. When ``None``, a new
            ``StandardScaler`` is fitted on ``states`` itself.

    Returns:
        Tuple ``(transformed_states, scaler)``, where ``scaler`` is the
        instance that produced the transform (either new or the one passed in).
    """
    if scaler is not None:
        # Reuse existing statistics without refitting.
        return scaler.transform(states), scaler
    fresh_scaler = StandardScaler()
    return fresh_scaler.fit_transform(states), fresh_scaler
|
|
|
|
|
|
|
|
def _sample_wallet_state(api_data, max_retries):
    """Draw random wallets until one produces a state vector.

    Returns:
        ``(wallet, state)`` for the first wallet ``extract_state`` accepts, or
        ``(None, None)`` if the collection is empty or every attempt fails.
    """
    for attempt in range(max_retries):
        # $sample yields at most one document here; next(..., None) handles an
        # empty collection instead of letting StopIteration escape.
        wallet = next(collection.aggregate([{"$sample": {"size": 1}}]), None)
        if wallet is None:
            print("The wallets collection is empty; nothing to test.")
            return None, None
        state = extract_state(wallet, api_data)
        if state is not None:
            return wallet, state
        print(f"Attempt {attempt + 1}: Wallet has no topHoldings. Retrying...")
    print("Failed to fetch a valid wallet with topHoldings after multiple attempts.")
    return None, None


def _print_token_actions(wallet, q_values):
    """Print the greedy (arg-max) action for each of the wallet's top holdings.

    ``q_values`` is read as consecutive triples ``[Buy, Sell, Hold]``, one
    triple per holding, in ``topHoldings`` order.
    """
    actions = ["Buy", "Sell", "Hold"]
    holdings = wallet["topHoldings"]
    for token_index, holding in enumerate(holdings):
        start = token_index * len(actions)
        token_q_values = q_values[start:start + len(actions)]
        if len(token_q_values) < len(actions):
            # Slicing past the end yields a short or empty array. np.argmax
            # raises on an empty one, and a partial triple would pick from the
            # wrong actions, so report the mismatch instead.
            print(
                f"Model produced {len(q_values)} Q-values, not enough for "
                f"{len(holdings)} holdings; stopping at holding {token_index + 1}."
            )
            break
        best_action = actions[int(np.argmax(token_q_values))]
        token_symbol = holding.get("symbol", "<unknown>")
        print(f"Token: {token_symbol}, Best action: {best_action}, Q-values: {token_q_values}")


def test_model_on_random_wallet(api_url, model_path, scaler=None, max_retries=10):
    """Load the saved Q-network and print its recommended actions for one random wallet.

    Args:
        api_url: Endpoint passed to ``fetch_aggregated_data`` for portfolio
            metrics.
        model_path: Path to the saved Keras model.
        scaler: The *fitted* scaler used on the training states (anything
            with a ``transform`` method, e.g. a fitted ``StandardScaler``).
            If omitted, a ``StandardScaler`` is fitted on the single test
            state. That maps every feature to 0, so the Q-values do not
            depend on the wallet, and a warning is printed.
        max_retries: Number of random wallets to try before giving up.

    Returns:
        None. Results are printed.
    """
    model = tf.keras.models.load_model(model_path, compile=False)
    print("Model loaded successfully.")

    # predict() does not need a compiled model; compiling only prepares the
    # model for fit/evaluate.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')
    print("Model recompiled with MSE loss.")

    api_data = fetch_aggregated_data(api_url)

    random_wallet, state = _sample_wallet_state(api_data, max_retries)
    if state is None:
        return

    address = random_wallet.get("address", "<unknown>")

    # The state has 5 + 2 * holdings features. A model with a fixed input
    # width cannot score a wallet whose holdings count differs from training.
    try:
        input_shape = model.input_shape
    except (AttributeError, ValueError):
        input_shape = None
    expected_dim = input_shape[-1] if isinstance(input_shape, tuple) else None
    if expected_dim is not None and state.shape[0] != expected_dim:
        print(
            f"Wallet {address} yields {state.shape[0]} features but the model "
            f"expects {expected_dim}; cannot score it."
        )
        return

    if scaler is None:
        print(
            "Warning: no fitted scaler supplied; fitting StandardScaler on a single "
            "state maps every feature to 0, so the Q-values below do not depend "
            "on the wallet."
        )
    states, scaler = normalize_states([state], scaler)
    normalized_state = states[0]

    q_values = model.predict(normalized_state.reshape(1, -1))[0]

    print(f"Testing on wallet: {address}")
    print(f"Q-values: {q_values}")
    _print_token_actions(random_wallet, q_values)
|
|
|
|
|
if __name__ == "__main__":
    # Trends endpoint that supplies the portfolio-wide metrics.
    API_URL = "https://soltrendio.com/api/stats/getTrends"
    # Saved Keras model (.h5), resolved relative to the working directory.
    MODEL_PATH = "rl_wallet_model.h5"

    test_model_on_random_wallet(api_url=API_URL, model_path=MODEL_PATH)