Upload 4 files
Browse files- ddqn.py +239 -0
- rl_wallet_model.h5 +3 -0
- test.py +122 -0
- train.py +280 -0
ddqn.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
from tensorflow.keras import layers
|
| 4 |
+
from pymongo import MongoClient
|
| 5 |
+
import requests
|
| 6 |
+
import random
|
| 7 |
+
from sklearn.preprocessing import StandardScaler
|
| 8 |
+
import os
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
from langchain_openai import OpenAI
|
| 11 |
+
from langchain.prompts import PromptTemplate
|
| 12 |
+
from operator import itemgetter
|
| 13 |
+
|
| 14 |
+
# Load environment variables from .env file
|
| 15 |
+
load_dotenv()
|
| 16 |
+
API_URL = "https://soltrendio.com/api/stats/getTrends" # Replace with your actual API URL
|
| 17 |
+
MODEL_PATH = "rl_wallet_model.h5" # Path to the trained model
|
| 18 |
+
# MongoDB Connection Setup
|
| 19 |
+
def get_mongo_connection():
|
| 20 |
+
MONGO_URI = os.getenv("MONGODB_URI")
|
| 21 |
+
print(f"Attempting to connect to MongoDB with URI: {MONGO_URI[:20]}...") # Only show start of URI for security
|
| 22 |
+
|
| 23 |
+
DB_NAME = "walletAnalyzer"
|
| 24 |
+
COLLECTION_NAME = "wallets"
|
| 25 |
+
|
| 26 |
+
client = MongoClient(
|
| 27 |
+
MONGO_URI,
|
| 28 |
+
tls=True,
|
| 29 |
+
tlsAllowInvalidCertificates=True
|
| 30 |
+
)
|
| 31 |
+
db = client[DB_NAME]
|
| 32 |
+
print("Successfully connected to MongoDB")
|
| 33 |
+
return db[COLLECTION_NAME]
|
| 34 |
+
|
| 35 |
+
collection = get_mongo_connection()
|
| 36 |
+
|
| 37 |
+
# Fetch Aggregated Data from API
|
| 38 |
+
def fetch_aggregated_data(api_url):
|
| 39 |
+
print(f"Fetching data from API: {api_url}")
|
| 40 |
+
response = requests.get(api_url)
|
| 41 |
+
print(f"API Response status code: {response.status_code}")
|
| 42 |
+
return response.json()
|
| 43 |
+
|
| 44 |
+
# Extract States from Wallet Data
|
| 45 |
+
def extract_state(wallet, api_data):
|
| 46 |
+
print(f"\nExtracting state for wallet: {wallet.get('address', 'Unknown address')}")
|
| 47 |
+
|
| 48 |
+
if wallet.get("topHoldings") is None or not wallet["topHoldings"]:
|
| 49 |
+
print("Warning: Wallet has no topHoldings")
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
print(f"Number of top holdings: {len(wallet['topHoldings'])}")
|
| 53 |
+
|
| 54 |
+
# Pad or truncate topHoldings to exactly 8 tokens
|
| 55 |
+
MAX_TOKENS = 8
|
| 56 |
+
padded_holdings = wallet['topHoldings'][:MAX_TOKENS] # Truncate if more than 8
|
| 57 |
+
while len(padded_holdings) < MAX_TOKENS: # Pad with empty holdings if less than 8
|
| 58 |
+
padded_holdings.append({
|
| 59 |
+
'marketCap': '0',
|
| 60 |
+
'price': '0',
|
| 61 |
+
'balance': '0',
|
| 62 |
+
'symbol': 'EMPTY'
|
| 63 |
+
})
|
| 64 |
+
|
| 65 |
+
market_caps = [float(h.get('marketCap', 0)) for h in padded_holdings]
|
| 66 |
+
token_prices = [float(h.get('price', 0)) for h in padded_holdings]
|
| 67 |
+
token_balances = [float(h.get('balance', 0)) for h in padded_holdings]
|
| 68 |
+
|
| 69 |
+
print(f"Market caps: {market_caps}")
|
| 70 |
+
print(f"Token prices: {token_prices}")
|
| 71 |
+
print(f"Token balances: {token_balances}")
|
| 72 |
+
|
| 73 |
+
avg_market_cap = sum(market_caps) / len(market_caps) if market_caps else 0
|
| 74 |
+
print(f"Average market cap: {avg_market_cap}")
|
| 75 |
+
|
| 76 |
+
# Ensure totalValue is a float
|
| 77 |
+
total_value = wallet['totalValue']
|
| 78 |
+
if isinstance(total_value, dict) and '$numberDouble' in total_value:
|
| 79 |
+
total_value = float(total_value['$numberDouble'])
|
| 80 |
+
print(f"Total wallet value: {total_value}")
|
| 81 |
+
|
| 82 |
+
# Create state vector with exactly 25 features:
|
| 83 |
+
# 5 portfolio metrics + (8 tokens × 2.5 features per token = 20 features)
|
| 84 |
+
state = [
|
| 85 |
+
total_value, # Total portfolio value
|
| 86 |
+
len(wallet['topHoldings']), # Original number of holdings
|
| 87 |
+
avg_market_cap, # Average market cap of top holdings
|
| 88 |
+
api_data['portfolioMetrics']['averagePortfolioValue'], # Average portfolio value
|
| 89 |
+
api_data['portfolioMetrics']['totalPortfolioValue'] # Total portfolio value
|
| 90 |
+
] + token_balances + token_prices # 8 balances + 8 prices = 16 features
|
| 91 |
+
|
| 92 |
+
# Total features: 5 + 8 + 8 = 21 features
|
| 93 |
+
# Add 4 more features to reach 25 (you might want to adjust these based on your model's requirements)
|
| 94 |
+
state.extend([0.0] * 4) # Padding with zeros to reach 25 features
|
| 95 |
+
|
| 96 |
+
print(f"Final state vector shape: {len(state)}")
|
| 97 |
+
print(f"State vector: {state}\n")
|
| 98 |
+
return np.array(state)
|
| 99 |
+
|
| 100 |
+
# Normalize Features
|
| 101 |
+
def normalize_states(states, scaler=None):
|
| 102 |
+
print("\nNormalizing states...")
|
| 103 |
+
# Convert list to numpy array if it isn't already
|
| 104 |
+
states = np.array(states)
|
| 105 |
+
print(f"Input states shape: {states.shape}")
|
| 106 |
+
|
| 107 |
+
if scaler is None:
|
| 108 |
+
scaler = StandardScaler()
|
| 109 |
+
states = scaler.fit_transform(states)
|
| 110 |
+
print("Created new scaler and fit_transformed states")
|
| 111 |
+
else:
|
| 112 |
+
states = scaler.transform(states)
|
| 113 |
+
print("Used existing scaler to transform states")
|
| 114 |
+
|
| 115 |
+
print(f"Normalized states shape: {states.shape}")
|
| 116 |
+
print(f"Normalized states sample: {states[0][:5]}...\n")
|
| 117 |
+
return states, scaler
|
| 118 |
+
|
| 119 |
+
# Add this new function to create the summary
|
| 120 |
+
def generate_trading_summary(wallet_address, holdings_analysis):
|
| 121 |
+
"""
|
| 122 |
+
Generate a natural language summary of trading recommendations using Langchain.
|
| 123 |
+
"""
|
| 124 |
+
try:
|
| 125 |
+
llm = OpenAI(temperature=0.7, openai_api_key=os.getenv("OPENAI_API_KEY"))
|
| 126 |
+
|
| 127 |
+
# Create a prompt template
|
| 128 |
+
template = """
|
| 129 |
+
As a crypto trading advisor, analyze the following wallet and its DDQN holdings analysis:
|
| 130 |
+
|
| 131 |
+
Wallet Address: {wallet_address}
|
| 132 |
+
|
| 133 |
+
Holdings Analysis:
|
| 134 |
+
{holdings_details}
|
| 135 |
+
|
| 136 |
+
Please provide a concise summary of the recommended trading actions, including:
|
| 137 |
+
1. Overall portfolio assessment
|
| 138 |
+
2. Specific recommendations for each token
|
| 139 |
+
3. Key opportunities and risks
|
| 140 |
+
|
| 141 |
+
Summary:
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
# Create prompt with template
|
| 145 |
+
prompt = PromptTemplate(
|
| 146 |
+
input_variables=["wallet_address", "holdings_details"],
|
| 147 |
+
template=template
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# Create and run the chain using the new pattern
|
| 151 |
+
chain = (
|
| 152 |
+
{"wallet_address": itemgetter("wallet_address"),
|
| 153 |
+
"holdings_details": itemgetter("holdings_details")}
|
| 154 |
+
| prompt
|
| 155 |
+
| llm
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# Invoke the chain with the new pattern
|
| 159 |
+
summary = chain.invoke({
|
| 160 |
+
"wallet_address": wallet_address,
|
| 161 |
+
"holdings_details": holdings_analysis
|
| 162 |
+
})
|
| 163 |
+
|
| 164 |
+
print("\n=== AI Generated Trading Summary ===")
|
| 165 |
+
print(summary)
|
| 166 |
+
return summary
|
| 167 |
+
|
| 168 |
+
except Exception as e:
|
| 169 |
+
print(f"Error generating summary: {str(e)}")
|
| 170 |
+
return None
|
| 171 |
+
|
| 172 |
+
# Load and Test the Trained Model
|
| 173 |
+
def test_model_on_wallet(wallet_identifier, api_url, model_path):
|
| 174 |
+
# Load the trained model without compilation
|
| 175 |
+
model = tf.keras.models.load_model(model_path, compile=False)
|
| 176 |
+
print("Model loaded successfully.")
|
| 177 |
+
|
| 178 |
+
# Recompile the model with the correct loss function
|
| 179 |
+
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')
|
| 180 |
+
print("Model recompiled with MSE loss.")
|
| 181 |
+
|
| 182 |
+
# Fetch API data
|
| 183 |
+
api_data = fetch_aggregated_data(api_url)
|
| 184 |
+
|
| 185 |
+
# Fetch wallet from database
|
| 186 |
+
wallet = collection.find_one({"$or": [{"address": wallet_identifier}, {"domain": wallet_identifier}]})
|
| 187 |
+
if wallet is None:
|
| 188 |
+
print(f"Wallet with identifier {wallet_identifier} not found in the database.")
|
| 189 |
+
return
|
| 190 |
+
|
| 191 |
+
# Extract state
|
| 192 |
+
state = extract_state(wallet, api_data)
|
| 193 |
+
if state is None:
|
| 194 |
+
print(f"Wallet {wallet_identifier} has no topHoldings or insufficient data.")
|
| 195 |
+
return
|
| 196 |
+
|
| 197 |
+
# Normalize the state (use a pre-fitted scaler if available)
|
| 198 |
+
states, scaler = normalize_states([state])
|
| 199 |
+
normalized_state = states[0] # Extract the first (and only) normalized state
|
| 200 |
+
|
| 201 |
+
# Predict actions using the model
|
| 202 |
+
q_values = model.predict(normalized_state.reshape(1, -1))[0]
|
| 203 |
+
|
| 204 |
+
# Iterate over all tokens in the wallet and output the best action
|
| 205 |
+
print(f"\nAnalyzing wallet: {wallet['address']}")
|
| 206 |
+
print(f"Q-values: {q_values}")
|
| 207 |
+
|
| 208 |
+
# Create a list to store analysis results
|
| 209 |
+
holdings_analysis = []
|
| 210 |
+
|
| 211 |
+
for token_index, holding in enumerate(wallet["topHoldings"]):
|
| 212 |
+
# Determine the best action for this token
|
| 213 |
+
actions = ["Buy", "Sell", "Hold"]
|
| 214 |
+
token_q_values = q_values[token_index * 3: (token_index + 1) * 3]
|
| 215 |
+
best_action_index = np.argmax(token_q_values)
|
| 216 |
+
best_action = actions[best_action_index]
|
| 217 |
+
|
| 218 |
+
# Calculate confidence as the difference between the best Q-value and the average of others
|
| 219 |
+
confidence = token_q_values[best_action_index] - np.mean(token_q_values)
|
| 220 |
+
|
| 221 |
+
# Store the analysis
|
| 222 |
+
holdings_analysis.append({
|
| 223 |
+
"symbol": holding["symbol"],
|
| 224 |
+
"balance": holding.get("balance", 0),
|
| 225 |
+
"action": best_action,
|
| 226 |
+
"confidence": confidence,
|
| 227 |
+
"q_values": token_q_values.tolist()
|
| 228 |
+
})
|
| 229 |
+
|
| 230 |
+
# Output the action for the token
|
| 231 |
+
print(f"Token: {holding['symbol']}, Best action: {best_action}, Q-values: {token_q_values}")
|
| 232 |
+
|
| 233 |
+
# Generate and display the trading summary
|
| 234 |
+
generate_trading_summary(wallet['address'], holdings_analysis)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
if __name__ == "__main__":
|
| 238 |
+
wallet_identifier = input("Enter the wallet address or .sol domain to analyze: ").strip()
|
| 239 |
+
test_model_on_wallet(wallet_identifier, API_URL, MODEL_PATH)
|
rl_wallet_model.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6a3055e68bb0f8eeb44e06cba1cb8fd3219cb8f1d74f5a65fb1a26684d1a8f24
|
| 3 |
+
size 145360
|
test.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
from tensorflow.keras import layers
|
| 4 |
+
from pymongo import MongoClient
|
| 5 |
+
import requests
|
| 6 |
+
import random
|
| 7 |
+
from sklearn.preprocessing import StandardScaler
|
| 8 |
+
import os
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
# Load environment variables from .env file
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
# MongoDB Connection Setup
|
| 15 |
+
def get_mongo_connection():
|
| 16 |
+
MONGO_URI = os.getenv("MONGODB_URI")
|
| 17 |
+
DB_NAME = "walletAnalyzer"
|
| 18 |
+
COLLECTION_NAME = "wallets"
|
| 19 |
+
|
| 20 |
+
client = MongoClient(
|
| 21 |
+
MONGO_URI,
|
| 22 |
+
tls=True,
|
| 23 |
+
tlsAllowInvalidCertificates=True
|
| 24 |
+
)
|
| 25 |
+
db = client[DB_NAME]
|
| 26 |
+
return db[COLLECTION_NAME]
|
| 27 |
+
|
| 28 |
+
collection = get_mongo_connection()
|
| 29 |
+
|
| 30 |
+
# Fetch Aggregated Data from API
|
| 31 |
+
def fetch_aggregated_data(api_url):
|
| 32 |
+
response = requests.get(api_url)
|
| 33 |
+
return response.json()
|
| 34 |
+
|
| 35 |
+
# Extract States from Wallet Data
|
| 36 |
+
def extract_state(wallet, api_data):
|
| 37 |
+
if wallet.get("topHoldings") is None or not wallet["topHoldings"]:
|
| 38 |
+
return None # Skip wallets without topHoldings
|
| 39 |
+
|
| 40 |
+
market_caps = [float(h.get('marketCap', 0)) for h in wallet['topHoldings']]
|
| 41 |
+
token_prices = [float(h.get('price', 0)) for h in wallet['topHoldings']]
|
| 42 |
+
token_balances = [float(h.get('balance', 0)) for h in wallet['topHoldings']]
|
| 43 |
+
|
| 44 |
+
avg_market_cap = sum(market_caps) / len(market_caps) if market_caps else 0
|
| 45 |
+
|
| 46 |
+
# Ensure totalValue is a float
|
| 47 |
+
total_value = wallet['totalValue']
|
| 48 |
+
if isinstance(total_value, dict) and '$numberDouble' in total_value:
|
| 49 |
+
total_value = float(total_value['$numberDouble'])
|
| 50 |
+
|
| 51 |
+
state = [
|
| 52 |
+
total_value, # Total portfolio value
|
| 53 |
+
len(wallet['topHoldings']), # Number of holdings
|
| 54 |
+
avg_market_cap, # Average market cap of top holdings
|
| 55 |
+
api_data['portfolioMetrics']['averagePortfolioValue'], # Average portfolio value
|
| 56 |
+
api_data['portfolioMetrics']['totalPortfolioValue'] # Total portfolio value
|
| 57 |
+
] + token_balances + token_prices
|
| 58 |
+
|
| 59 |
+
return np.array(state)
|
| 60 |
+
|
| 61 |
+
# Normalize Features
|
| 62 |
+
def normalize_states(states, scaler=None):
|
| 63 |
+
if scaler is None:
|
| 64 |
+
scaler = StandardScaler()
|
| 65 |
+
states = scaler.fit_transform(states)
|
| 66 |
+
else:
|
| 67 |
+
states = scaler.transform(states)
|
| 68 |
+
return states, scaler
|
| 69 |
+
|
| 70 |
+
# Load and Test the Trained Model
|
| 71 |
+
def test_model_on_random_wallet(api_url, model_path):
|
| 72 |
+
# Load the trained model without compilation
|
| 73 |
+
model = tf.keras.models.load_model(model_path, compile=False)
|
| 74 |
+
print("Model loaded successfully.")
|
| 75 |
+
|
| 76 |
+
# Recompile the model with the correct loss function
|
| 77 |
+
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')
|
| 78 |
+
print("Model recompiled with MSE loss.")
|
| 79 |
+
|
| 80 |
+
# Fetch API data
|
| 81 |
+
api_data = fetch_aggregated_data(api_url)
|
| 82 |
+
|
| 83 |
+
# Retry fetching a wallet with topHoldings
|
| 84 |
+
max_retries = 10
|
| 85 |
+
for attempt in range(max_retries):
|
| 86 |
+
random_wallet = collection.aggregate([{"$sample": {"size": 1}}]).next()
|
| 87 |
+
state = extract_state(random_wallet, api_data)
|
| 88 |
+
if state is not None:
|
| 89 |
+
break
|
| 90 |
+
print(f"Attempt {attempt + 1}: Wallet has no topHoldings. Retrying...")
|
| 91 |
+
else:
|
| 92 |
+
print("Failed to fetch a valid wallet with topHoldings after multiple attempts.")
|
| 93 |
+
return
|
| 94 |
+
|
| 95 |
+
# Normalize the state (use a pre-fitted scaler if available)
|
| 96 |
+
states, scaler = normalize_states([state])
|
| 97 |
+
normalized_state = states[0] # Extract the first (and only) normalized state
|
| 98 |
+
|
| 99 |
+
# Predict actions using the model
|
| 100 |
+
q_values = model.predict(normalized_state.reshape(1, -1))[0]
|
| 101 |
+
|
| 102 |
+
# Iterate over all tokens in the wallet and output the best action
|
| 103 |
+
print(f"Testing on wallet: {random_wallet['address']}")
|
| 104 |
+
print(f"Q-values: {q_values}")
|
| 105 |
+
|
| 106 |
+
for token_index, holding in enumerate(random_wallet["topHoldings"]):
|
| 107 |
+
# Determine the best action for this token
|
| 108 |
+
actions = ["Buy", "Sell", "Hold"]
|
| 109 |
+
token_q_values = q_values[token_index * 3: (token_index + 1) * 3]
|
| 110 |
+
best_action_index = np.argmax(token_q_values)
|
| 111 |
+
best_action = actions[best_action_index]
|
| 112 |
+
|
| 113 |
+
# Output the action for the token
|
| 114 |
+
token_symbol = holding["symbol"]
|
| 115 |
+
print(f"Token: {token_symbol}, Best action: {best_action}, Q-values: {token_q_values}")
|
| 116 |
+
|
| 117 |
+
if __name__ == "__main__":
|
| 118 |
+
API_URL = "https://soltrendio.com/api/stats/getTrends" # Replace with your actual API URL
|
| 119 |
+
MODEL_PATH = "rl_wallet_model.h5" # Path to the trained model
|
| 120 |
+
|
| 121 |
+
# Test the model on a random wallet
|
| 122 |
+
test_model_on_random_wallet(API_URL, MODEL_PATH)
|
train.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
from tensorflow.keras import layers
|
| 4 |
+
from pymongo import MongoClient
|
| 5 |
+
import requests
|
| 6 |
+
import random
|
| 7 |
+
from collections import deque
|
| 8 |
+
from sklearn.preprocessing import StandardScaler
|
| 9 |
+
import os
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
|
| 12 |
+
# Load environment variables from .env file
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
# MongoDB Connection Setup
|
| 16 |
+
# Define connection parameters and establish connection to MongoDB
|
| 17 |
+
def get_mongo_connection():
|
| 18 |
+
MONGO_URI = os.getenv("MONGODB_URI")
|
| 19 |
+
DB_NAME = "walletAnalyzer"
|
| 20 |
+
COLLECTION_NAME = "wallets"
|
| 21 |
+
|
| 22 |
+
client = MongoClient(
|
| 23 |
+
MONGO_URI,
|
| 24 |
+
tls=True,
|
| 25 |
+
tlsAllowInvalidCertificates=True
|
| 26 |
+
)
|
| 27 |
+
db = client[DB_NAME]
|
| 28 |
+
return db[COLLECTION_NAME]
|
| 29 |
+
|
| 30 |
+
collection = get_mongo_connection()
|
| 31 |
+
|
| 32 |
+
# DexScreener API URL
|
| 33 |
+
DEX_API_URL = "https://api.dexscreener.com/latest/dex/tokens"
|
| 34 |
+
|
| 35 |
+
# Fetch Aggregated Data from API
|
| 36 |
+
# Retrieve portfolio and token trends from external API
|
| 37 |
+
def fetch_aggregated_data(api_url):
|
| 38 |
+
response = requests.get(api_url)
|
| 39 |
+
return response.json()
|
| 40 |
+
|
| 41 |
+
# Retrieve Contract Address for Token
|
| 42 |
+
# Find the contract address of a Solana token by its ticker symbol
|
| 43 |
+
def get_solana_token_ca(ticker):
|
| 44 |
+
try:
|
| 45 |
+
url = f"https://api.dexscreener.com/latest/dex/search?q={ticker}"
|
| 46 |
+
response = requests.get(url)
|
| 47 |
+
response.raise_for_status()
|
| 48 |
+
data = response.json()
|
| 49 |
+
|
| 50 |
+
first_found_address = None
|
| 51 |
+
for pair in data.get('pairs', []):
|
| 52 |
+
if pair.get('chainId') == 'solana':
|
| 53 |
+
base_token = pair.get('baseToken', {})
|
| 54 |
+
quote_token = pair.get('quoteToken', {})
|
| 55 |
+
|
| 56 |
+
if first_found_address is None:
|
| 57 |
+
first_found_address = base_token.get('address') or quote_token.get('address')
|
| 58 |
+
|
| 59 |
+
if base_token.get('symbol').upper() == ticker.upper():
|
| 60 |
+
return base_token.get('address')
|
| 61 |
+
elif quote_token.get('symbol').upper() == ticker.upper():
|
| 62 |
+
return quote_token.get('address')
|
| 63 |
+
|
| 64 |
+
return first_found_address
|
| 65 |
+
except requests.RequestException as e:
|
| 66 |
+
print(f"Error fetching token data: {e}")
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
# Fetch Market Data for Token
|
| 70 |
+
# Retrieve market cap and price data for a token contract address
|
| 71 |
+
def fetch_market_data(contract_address):
|
| 72 |
+
try:
|
| 73 |
+
response = requests.get(f"{DEX_API_URL}/{contract_address}")
|
| 74 |
+
if response.status_code == 200:
|
| 75 |
+
data = response.json()
|
| 76 |
+
if "pairs" in data and len(data["pairs"]) > 0:
|
| 77 |
+
pair = data["pairs"][0]
|
| 78 |
+
return {
|
| 79 |
+
"marketCap": pair.get("marketCap"),
|
| 80 |
+
"price": pair.get("priceUsd"),
|
| 81 |
+
}
|
| 82 |
+
else:
|
| 83 |
+
print(f"Failed to fetch data for contract address {contract_address}: {response.status_code}")
|
| 84 |
+
except Exception as e:
|
| 85 |
+
print(f"Error fetching market data for contract address {contract_address}: {e}")
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
# Update Wallets in Database
|
| 89 |
+
# Enhance wallet data with token market information and contract addresses
|
| 90 |
+
def update_wallets():
|
| 91 |
+
wallets = collection.find()
|
| 92 |
+
for wallet in wallets:
|
| 93 |
+
updated = False
|
| 94 |
+
if wallet.get("topHoldings") is None or not wallet["topHoldings"]:
|
| 95 |
+
continue
|
| 96 |
+
for holding in wallet.get("topHoldings", []):
|
| 97 |
+
symbol = holding.get("symbol")
|
| 98 |
+
contract_address = holding.get("contractAddress")
|
| 99 |
+
|
| 100 |
+
if not contract_address:
|
| 101 |
+
contract_address = get_solana_token_ca(symbol)
|
| 102 |
+
if contract_address:
|
| 103 |
+
holding["contractAddress"] = contract_address
|
| 104 |
+
print(f"Updated {symbol} with contract address {contract_address}.")
|
| 105 |
+
updated = True
|
| 106 |
+
|
| 107 |
+
if "marketCap" not in holding or "price" not in holding:
|
| 108 |
+
market_data = fetch_market_data(contract_address)
|
| 109 |
+
if market_data:
|
| 110 |
+
holding["marketCap"] = market_data["marketCap"]
|
| 111 |
+
holding["price"] = market_data["price"]
|
| 112 |
+
updated = True
|
| 113 |
+
print(f"Updated {symbol} in wallet {wallet['address']} with marketCap and price.")
|
| 114 |
+
|
| 115 |
+
if updated:
|
| 116 |
+
collection.update_one({"_id": wallet["_id"]}, {"$set": {"topHoldings": wallet["topHoldings"]}})
|
| 117 |
+
|
| 118 |
+
# Extract States from Wallet Data
|
| 119 |
+
# Convert wallet data into numerical state vectors for RL training
|
| 120 |
+
def extract_state(wallet, api_data, max_tokens=10):
|
| 121 |
+
if wallet.get("topHoldings") is None or not wallet["topHoldings"]:
|
| 122 |
+
return None # Skip wallets without topHoldings
|
| 123 |
+
|
| 124 |
+
# Extract market caps, token prices, and balances
|
| 125 |
+
market_caps = [float(h.get('marketCap', 0)) for h in wallet['topHoldings'][:max_tokens]]
|
| 126 |
+
token_prices = [float(h.get('price', 0)) for h in wallet['topHoldings'][:max_tokens]]
|
| 127 |
+
token_balances = [float(h.get('balance', 0)) for h in wallet['topHoldings'][:max_tokens]]
|
| 128 |
+
|
| 129 |
+
# Pad with zeros to ensure fixed length
|
| 130 |
+
market_caps += [0] * (max_tokens - len(market_caps))
|
| 131 |
+
token_prices += [0] * (max_tokens - len(token_prices))
|
| 132 |
+
token_balances += [0] * (max_tokens - len(token_balances))
|
| 133 |
+
|
| 134 |
+
avg_market_cap = sum(market_caps) / len(market_caps) if market_caps else 0
|
| 135 |
+
|
| 136 |
+
# Ensure totalValue is a float
|
| 137 |
+
total_value = wallet['totalValue']
|
| 138 |
+
if isinstance(total_value, dict) and '$numberDouble' in total_value:
|
| 139 |
+
total_value = float(total_value['$numberDouble'])
|
| 140 |
+
|
| 141 |
+
state = [
|
| 142 |
+
total_value, # Total portfolio value
|
| 143 |
+
len(wallet['topHoldings']), # Number of holdings
|
| 144 |
+
avg_market_cap, # Average market cap of top holdings
|
| 145 |
+
api_data['portfolioMetrics']['averagePortfolioValue'], # Average portfolio value
|
| 146 |
+
api_data['portfolioMetrics']['totalPortfolioValue'], # Total portfolio value
|
| 147 |
+
] + token_balances + token_prices # Include token balances and prices
|
| 148 |
+
|
| 149 |
+
return np.array(state)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# Normalize Features
|
| 154 |
+
# Standardize state features for improved model training
|
| 155 |
+
def normalize_states(states):
|
| 156 |
+
scaler = StandardScaler()
|
| 157 |
+
return scaler.fit_transform(states)
|
| 158 |
+
|
| 159 |
+
# Create DQN Model
|
| 160 |
+
# Define the neural network for the RL agent
|
| 161 |
+
def create_dqn(state_size, action_size):
|
| 162 |
+
model = tf.keras.Sequential([
|
| 163 |
+
layers.Dense(64, activation='relu', input_shape=(state_size,)),
|
| 164 |
+
layers.Dense(64, activation='relu'),
|
| 165 |
+
layers.Dense(action_size, activation='linear')
|
| 166 |
+
])
|
| 167 |
+
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')
|
| 168 |
+
return model
|
| 169 |
+
|
| 170 |
+
# Simulate Next State
|
| 171 |
+
# Generate the next state based on the action taken and simulated market changes
|
| 172 |
+
def simulate_next_state(state, action, token_prices, token_balances):
|
| 173 |
+
token_idx = action // 3 # Determine token index
|
| 174 |
+
action_type = action % 3 # Determine action type (Buy, Sell, Hold)
|
| 175 |
+
|
| 176 |
+
if action_type == 0: # Buy
|
| 177 |
+
state[token_idx + 5] += 1 # Increase token balance
|
| 178 |
+
elif action_type == 1: # Sell
|
| 179 |
+
state[token_idx + 5] -= 1 # Decrease token balance
|
| 180 |
+
|
| 181 |
+
# Simulate market changes in token prices
|
| 182 |
+
token_prices = [price * np.random.uniform(0.95, 1.05) for price in token_prices]
|
| 183 |
+
|
| 184 |
+
# Update total portfolio value based on new token balances and prices
|
| 185 |
+
total_value = sum(balance * price for balance, price in zip(token_balances, token_prices))
|
| 186 |
+
state[0] = total_value # Update total portfolio value
|
| 187 |
+
|
| 188 |
+
return state, token_prices
|
| 189 |
+
|
| 190 |
+
# Termination Logic
|
| 191 |
+
# Define conditions for ending an RL episode
|
| 192 |
+
def check_termination(state, steps, max_steps):
|
| 193 |
+
if state[0] <= 0 or steps >= max_steps:
|
| 194 |
+
return True
|
| 195 |
+
return False
|
| 196 |
+
|
| 197 |
+
# Reward Function
|
| 198 |
+
# Calculate rewards based on portfolio performance and risk
|
| 199 |
+
def calculate_reward(old_value, new_value, diversification_score):
|
| 200 |
+
value_change = (new_value - old_value) / old_value if old_value > 0 else 0
|
| 201 |
+
reward = value_change - (1 - diversification_score) # Penalize lack of diversification
|
| 202 |
+
return reward
|
| 203 |
+
|
| 204 |
+
# Main Training Loop
|
| 205 |
+
# Train the RL agent by simulating interactions with the environment
|
| 206 |
+
# Main Training Loop
|
| 207 |
+
def train_rl_model(api_url):
|
| 208 |
+
update_wallets() # Ensure wallets are updated before training
|
| 209 |
+
|
| 210 |
+
data = collection.find() # Fetch data directly from updated collection
|
| 211 |
+
api_data = fetch_aggregated_data(api_url)
|
| 212 |
+
|
| 213 |
+
# Extract valid states, skipping None
|
| 214 |
+
states = [
|
| 215 |
+
state for wallet in data
|
| 216 |
+
if (state := extract_state(wallet, api_data)) is not None
|
| 217 |
+
]
|
| 218 |
+
|
| 219 |
+
# Ensure states list is not empty before proceeding
|
| 220 |
+
if not states:
|
| 221 |
+
print("No valid states available for training.")
|
| 222 |
+
return
|
| 223 |
+
|
| 224 |
+
states = normalize_states(states)
|
| 225 |
+
|
| 226 |
+
state_size = len(states[0])
|
| 227 |
+
action_size = 3 * len(states[0][5:]) # Buy, Sell, Hold for each token
|
| 228 |
+
|
| 229 |
+
model = create_dqn(state_size, action_size)
|
| 230 |
+
replay_buffer = deque(maxlen=2000)
|
| 231 |
+
gamma = 0.99
|
| 232 |
+
|
| 233 |
+
episodes = 100
|
| 234 |
+
for episode in range(episodes):
|
| 235 |
+
state = states[np.random.randint(0, len(states))] # Random start
|
| 236 |
+
token_prices = state[-len(state[5:]):] # Extract token prices from state
|
| 237 |
+
token_balances = state[5:5 + len(state[5:])] # Extract token balances from state
|
| 238 |
+
steps = 0
|
| 239 |
+
done = False
|
| 240 |
+
|
| 241 |
+
while not done:
|
| 242 |
+
if np.random.rand() < 0.1: # Exploration
|
| 243 |
+
action = np.random.choice(action_size)
|
| 244 |
+
else: # Exploitation
|
| 245 |
+
action = np.argmax(model.predict(state.reshape(1, -1))[0])
|
| 246 |
+
|
| 247 |
+
next_state, token_prices = simulate_next_state(state, action, token_prices, token_balances)
|
| 248 |
+
diversification_score = len(set(token_balances)) / len(token_balances) # Diversity metric
|
| 249 |
+
reward = calculate_reward(state[0], next_state[0], diversification_score)
|
| 250 |
+
done = check_termination(next_state, steps, max_steps=50)
|
| 251 |
+
|
| 252 |
+
replay_buffer.append((state, action, reward, next_state, done))
|
| 253 |
+
|
| 254 |
+
if len(replay_buffer) > 32:
|
| 255 |
+
minibatch = random.sample(replay_buffer, 32)
|
| 256 |
+
states_mb, actions_mb, rewards_mb, next_states_mb, dones_mb = zip(*minibatch)
|
| 257 |
+
|
| 258 |
+
targets = model.predict(np.array(states_mb))
|
| 259 |
+
next_q_values = model.predict(np.array(next_states_mb))
|
| 260 |
+
|
| 261 |
+
for i in range(32):
|
| 262 |
+
if dones_mb[i]:
|
| 263 |
+
targets[i][actions_mb[i]] = rewards_mb[i]
|
| 264 |
+
else:
|
| 265 |
+
targets[i][actions_mb[i]] = rewards_mb[i] + gamma * np.max(next_q_values[i])
|
| 266 |
+
|
| 267 |
+
model.fit(np.array(states_mb), targets, epochs=1, verbose=0)
|
| 268 |
+
|
| 269 |
+
state = next_state
|
| 270 |
+
steps += 1
|
| 271 |
+
|
| 272 |
+
print(f"Episode {episode + 1}/{episodes} completed.")
|
| 273 |
+
|
| 274 |
+
model.save("rl_wallet_model.h5")
|
| 275 |
+
print("Model training complete and saved.")
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
if __name__ == "__main__":
|
| 279 |
+
API_URL = "https://soltrendio.com/api/stats/getTrends" # Replace with your actual API URL
|
| 280 |
+
train_rl_model(API_URL)
|