jonngan commited on
Commit
24c3333
·
verified ·
1 Parent(s): b0fa50c

Upload 4 files

Browse files
Files changed (4) hide show
  1. ddqn.py +239 -0
  2. rl_wallet_model.h5 +3 -0
  3. test.py +122 -0
  4. train.py +280 -0
ddqn.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from tensorflow.keras import layers
4
+ from pymongo import MongoClient
5
+ import requests
6
+ import random
7
+ from sklearn.preprocessing import StandardScaler
8
+ import os
9
+ from dotenv import load_dotenv
10
+ from langchain_openai import OpenAI
11
+ from langchain.prompts import PromptTemplate
12
+ from operator import itemgetter
13
+
14
# Load environment variables from .env file (provides MONGODB_URI and OPENAI_API_KEY)
load_dotenv()
API_URL = "https://soltrendio.com/api/stats/getTrends"  # Replace with your actual API URL
MODEL_PATH = "rl_wallet_model.h5"  # Path to the trained model
18
+ # MongoDB Connection Setup
19
def get_mongo_connection():
    """Return the 'wallets' collection from the walletAnalyzer database.

    Raises:
        RuntimeError: if the MONGODB_URI environment variable is unset —
            the original crashed here with TypeError on `MONGO_URI[:20]`.
    """
    MONGO_URI = os.getenv("MONGODB_URI")
    if not MONGO_URI:
        raise RuntimeError("MONGODB_URI environment variable is not set")
    print(f"Attempting to connect to MongoDB with URI: {MONGO_URI[:20]}...")  # Only show start of URI for security

    DB_NAME = "walletAnalyzer"
    COLLECTION_NAME = "wallets"

    # NOTE(review): tlsAllowInvalidCertificates=True disables certificate
    # verification — acceptable for development only; confirm before production.
    client = MongoClient(
        MONGO_URI,
        tls=True,
        tlsAllowInvalidCertificates=True
    )
    db = client[DB_NAME]
    # MongoClient connects lazily, so this confirms construction, not a live server.
    print("Successfully connected to MongoDB")
    return db[COLLECTION_NAME]
34
+
35
# Module-level handle to the 'wallets' collection, shared by the functions below.
collection = get_mongo_connection()
36
+
37
+ # Fetch Aggregated Data from API
38
def fetch_aggregated_data(api_url, timeout=30):
    """GET aggregated trend stats from *api_url* and return the parsed JSON.

    Args:
        api_url: Endpoint returning portfolioMetrics JSON.
        timeout: Seconds before the request aborts; the original call had no
            timeout, so a stalled server could hang the caller forever.
    """
    print(f"Fetching data from API: {api_url}")
    response = requests.get(api_url, timeout=timeout)
    print(f"API Response status code: {response.status_code}")
    return response.json()
43
+
44
+ # Extract States from Wallet Data
45
def extract_state(wallet, api_data):
    """Build the fixed-length 25-feature state vector for one wallet.

    Layout: [totalValue, holdingCount, avgMarketCap, apiAvgPortfolioValue,
    apiTotalPortfolioValue] + 8 token balances + 8 token prices + 4 zero pads.

    Returns:
        numpy array of length 25, or None when the wallet has no topHoldings.
    """
    def _to_float(value):
        # Tolerate None, numeric strings, and Mongo extended-JSON wrappers
        # like {'$numberDouble': '1.5'}; the original float(...) raised
        # TypeError when a holding carried a None marketCap/price/balance.
        if isinstance(value, dict):
            value = value.get('$numberDouble', 0)
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    print(f"\nExtracting state for wallet: {wallet.get('address', 'Unknown address')}")

    if not wallet.get("topHoldings"):
        print("Warning: Wallet has no topHoldings")
        return None

    print(f"Number of top holdings: {len(wallet['topHoldings'])}")

    # Pad or truncate topHoldings to exactly 8 tokens. Slicing copies the
    # list, so padding never mutates the wallet document itself.
    MAX_TOKENS = 8
    padded_holdings = wallet['topHoldings'][:MAX_TOKENS]
    while len(padded_holdings) < MAX_TOKENS:
        padded_holdings.append({
            'marketCap': '0',
            'price': '0',
            'balance': '0',
            'symbol': 'EMPTY'
        })

    market_caps = [_to_float(h.get('marketCap')) for h in padded_holdings]
    token_prices = [_to_float(h.get('price')) for h in padded_holdings]
    token_balances = [_to_float(h.get('balance')) for h in padded_holdings]

    print(f"Market caps: {market_caps}")
    print(f"Token prices: {token_prices}")
    print(f"Token balances: {token_balances}")

    avg_market_cap = sum(market_caps) / len(market_caps) if market_caps else 0
    print(f"Average market cap: {avg_market_cap}")

    # totalValue may arrive as a Mongo extended-JSON dict; normalize to float.
    total_value = _to_float(wallet['totalValue'])
    print(f"Total wallet value: {total_value}")

    # 5 portfolio metrics + 8 balances + 8 prices = 21 features.
    state = [
        total_value,                 # Total portfolio value
        len(wallet['topHoldings']),  # Original (unpadded) number of holdings
        avg_market_cap,              # Average market cap of top holdings
        api_data['portfolioMetrics']['averagePortfolioValue'],  # Average portfolio value
        api_data['portfolioMetrics']['totalPortfolioValue']     # Total portfolio value
    ] + token_balances + token_prices

    # Zero-pad to the 25 features the trained model expects.
    state.extend([0.0] * 4)

    print(f"Final state vector shape: {len(state)}")
    print(f"State vector: {state}\n")
    return np.array(state)
99
+
100
+ # Normalize Features
101
def normalize_states(states, scaler=None):
    """Standardize state vectors, fitting a new StandardScaler when none is given.

    Returns the transformed matrix together with the scaler that produced it,
    so callers can reuse the same scaling for later samples.
    """
    print("\nNormalizing states...")
    matrix = np.array(states)
    print(f"Input states shape: {matrix.shape}")

    if scaler is None:
        scaler = StandardScaler()
        matrix = scaler.fit_transform(matrix)
        print("Created new scaler and fit_transformed states")
    else:
        matrix = scaler.transform(matrix)
        print("Used existing scaler to transform states")

    print(f"Normalized states shape: {matrix.shape}")
    print(f"Normalized states sample: {matrix[0][:5]}...\n")
    return matrix, scaler
118
+
119
+ # Add this new function to create the summary
120
def generate_trading_summary(wallet_address, holdings_analysis):
    """
    Generate a natural language summary of trading recommendations using Langchain.

    Args:
        wallet_address: Address string interpolated into the prompt.
        holdings_analysis: Per-token analysis (list of dicts with symbol,
            balance, action, confidence, q_values) rendered into the prompt.

    Returns:
        The LLM completion string, or None on any failure (errors are
        printed, never raised).
    """
    try:
        # Requires OPENAI_API_KEY in the environment (loaded via load_dotenv).
        llm = OpenAI(temperature=0.7, openai_api_key=os.getenv("OPENAI_API_KEY"))

        # Create a prompt template
        template = """
        As a crypto trading advisor, analyze the following wallet and its DDQN holdings analysis:

        Wallet Address: {wallet_address}

        Holdings Analysis:
        {holdings_details}

        Please provide a concise summary of the recommended trading actions, including:
        1. Overall portfolio assessment
        2. Specific recommendations for each token
        3. Key opportunities and risks

        Summary:
        """

        # Create prompt with template
        prompt = PromptTemplate(
            input_variables=["wallet_address", "holdings_details"],
            template=template
        )

        # Create and run the chain using the new pattern
        # (LangChain Expression Language: input dict -> prompt -> llm)
        chain = (
            {"wallet_address": itemgetter("wallet_address"),
             "holdings_details": itemgetter("holdings_details")}
            | prompt
            | llm
        )

        # Invoke the chain with the new pattern
        summary = chain.invoke({
            "wallet_address": wallet_address,
            "holdings_details": holdings_analysis
        })

        print("\n=== AI Generated Trading Summary ===")
        print(summary)
        return summary

    except Exception as e:
        # Broad catch keeps the CLI flow alive if the LLM call fails; the
        # caller treats None as "no summary available".
        print(f"Error generating summary: {str(e)}")
        return None
171
+
172
+ # Load and Test the Trained Model
173
def test_model_on_wallet(wallet_identifier, api_url, model_path):
    """Load the trained model, score one wallet, and print per-token actions.

    Args:
        wallet_identifier: Wallet address or .sol domain looked up in Mongo.
        api_url: Endpoint for aggregated portfolio metrics.
        model_path: Path to the saved Keras model (.h5).

    Side effects only (prints + LLM summary); returns None.
    """
    # Load the trained model without compilation
    model = tf.keras.models.load_model(model_path, compile=False)
    print("Model loaded successfully.")

    # Recompile the model with the correct loss function
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')
    print("Model recompiled with MSE loss.")

    # Fetch API data
    api_data = fetch_aggregated_data(api_url)

    # Fetch wallet from database (matches either the address or the domain)
    wallet = collection.find_one({"$or": [{"address": wallet_identifier}, {"domain": wallet_identifier}]})
    if wallet is None:
        print(f"Wallet with identifier {wallet_identifier} not found in the database.")
        return

    # Extract state
    state = extract_state(wallet, api_data)
    if state is None:
        print(f"Wallet {wallet_identifier} has no topHoldings or insufficient data.")
        return

    # Normalize the state (use a pre-fitted scaler if available)
    # NOTE(review): a fresh StandardScaler is fit on this single sample, which
    # standardizes every feature to zero — consider persisting the
    # training-time scaler and reusing it here; confirm intent.
    states, scaler = normalize_states([state])
    normalized_state = states[0]  # Extract the first (and only) normalized state

    # Predict actions using the model
    q_values = model.predict(normalized_state.reshape(1, -1))[0]

    # Iterate over all tokens in the wallet and output the best action
    print(f"\nAnalyzing wallet: {wallet['address']}")
    print(f"Q-values: {q_values}")

    # Create a list to store analysis results
    holdings_analysis = []

    for token_index, holding in enumerate(wallet["topHoldings"]):
        # Determine the best action for this token (3 Q-values per token)
        actions = ["Buy", "Sell", "Hold"]
        # NOTE(review): assumes the model emits 3 Q-values per holding; with
        # more holdings than the model covers this slice is empty and
        # np.argmax raises — verify against the model's output size.
        token_q_values = q_values[token_index * 3: (token_index + 1) * 3]
        best_action_index = np.argmax(token_q_values)
        best_action = actions[best_action_index]

        # Calculate confidence as the difference between the best Q-value and the average of others
        confidence = token_q_values[best_action_index] - np.mean(token_q_values)

        # Store the analysis
        holdings_analysis.append({
            "symbol": holding["symbol"],
            "balance": holding.get("balance", 0),
            "action": best_action,
            "confidence": confidence,
            "q_values": token_q_values.tolist()
        })

        # Output the action for the token
        print(f"Token: {holding['symbol']}, Best action: {best_action}, Q-values: {token_q_values}")

    # Generate and display the trading summary
    generate_trading_summary(wallet['address'], holdings_analysis)
235
+
236
+
237
if __name__ == "__main__":
    # Interactive entry point: prompt for a wallet and run the full analysis.
    wallet_identifier = input("Enter the wallet address or .sol domain to analyze: ").strip()
    test_model_on_wallet(wallet_identifier, API_URL, MODEL_PATH)
rl_wallet_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a3055e68bb0f8eeb44e06cba1cb8fd3219cb8f1d74f5a65fb1a26684d1a8f24
3
+ size 145360
test.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from tensorflow.keras import layers
4
+ from pymongo import MongoClient
5
+ import requests
6
+ import random
7
+ from sklearn.preprocessing import StandardScaler
8
+ import os
9
+ from dotenv import load_dotenv
10
+
11
# Load environment variables from .env file (provides MONGODB_URI)
load_dotenv()
13
+
14
+ # MongoDB Connection Setup
15
def get_mongo_connection():
    """Return the 'wallets' collection from the walletAnalyzer database.

    Raises:
        RuntimeError: if MONGODB_URI is unset — MongoClient(None) would
            otherwise silently fall back to localhost:27017.
    """
    MONGO_URI = os.getenv("MONGODB_URI")
    if not MONGO_URI:
        raise RuntimeError("MONGODB_URI environment variable is not set")
    DB_NAME = "walletAnalyzer"
    COLLECTION_NAME = "wallets"

    # NOTE(review): tlsAllowInvalidCertificates=True disables certificate
    # verification; acceptable for development only.
    client = MongoClient(
        MONGO_URI,
        tls=True,
        tlsAllowInvalidCertificates=True
    )
    db = client[DB_NAME]
    return db[COLLECTION_NAME]
27
+
28
# Module-level handle to the 'wallets' collection, shared by the functions below.
collection = get_mongo_connection()
29
+
30
+ # Fetch Aggregated Data from API
31
def fetch_aggregated_data(api_url, timeout=30):
    """GET aggregated trend stats from *api_url* and return the parsed JSON.

    Args:
        api_url: Endpoint returning portfolioMetrics JSON.
        timeout: Seconds before the request aborts; the original call had no
            timeout, so a stalled server could hang the caller forever.
    """
    response = requests.get(api_url, timeout=timeout)
    return response.json()
34
+
35
+ # Extract States from Wallet Data
36
def extract_state(wallet, api_data):
    """Build a numeric state vector for one wallet.

    Layout: [totalValue, holdingCount, avgMarketCap, apiAvgPortfolioValue,
    apiTotalPortfolioValue] + one balance and one price per holding.

    Returns:
        numpy array, or None when the wallet has no topHoldings.

    NOTE(review): the vector length varies with the number of holdings here,
    while ddqn.py pads to a fixed 25 features for the same saved model —
    confirm which shape rl_wallet_model.h5 actually expects.
    """
    if wallet.get("topHoldings") is None or not wallet["topHoldings"]:
        return None  # Skip wallets without topHoldings

    def _to_float(value):
        # Holdings enriched from DexScreener can carry None for marketCap or
        # price; float(None) raised TypeError in the original, so coerce bad
        # values to 0.0 instead.
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    market_caps = [_to_float(h.get('marketCap', 0)) for h in wallet['topHoldings']]
    token_prices = [_to_float(h.get('price', 0)) for h in wallet['topHoldings']]
    token_balances = [_to_float(h.get('balance', 0)) for h in wallet['topHoldings']]

    avg_market_cap = sum(market_caps) / len(market_caps) if market_caps else 0

    # Ensure totalValue is a float (Mongo extended JSON wraps doubles)
    total_value = wallet['totalValue']
    if isinstance(total_value, dict) and '$numberDouble' in total_value:
        total_value = float(total_value['$numberDouble'])

    state = [
        total_value,                 # Total portfolio value
        len(wallet['topHoldings']),  # Number of holdings
        avg_market_cap,              # Average market cap of top holdings
        api_data['portfolioMetrics']['averagePortfolioValue'],  # Average portfolio value
        api_data['portfolioMetrics']['totalPortfolioValue']     # Total portfolio value
    ] + token_balances + token_prices

    return np.array(state)
60
+
61
+ # Normalize Features
62
def normalize_states(states, scaler=None):
    """Standardize state vectors.

    Fits a fresh StandardScaler when *scaler* is None, otherwise reuses the
    one supplied. Returns (normalized_states, scaler).
    """
    if scaler is not None:
        return scaler.transform(states), scaler
    fresh = StandardScaler()
    return fresh.fit_transform(states), fresh
69
+
70
+ # Load and Test the Trained Model
71
def test_model_on_random_wallet(api_url, model_path):
    """Load the trained model and print Buy/Sell/Hold picks for one random wallet.

    Args:
        api_url: Endpoint for aggregated portfolio metrics.
        model_path: Path to the saved Keras model (.h5).

    Side effects only (prints); returns None.
    """
    # Load the trained model without compilation
    model = tf.keras.models.load_model(model_path, compile=False)
    print("Model loaded successfully.")

    # Recompile the model with the correct loss function
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')
    print("Model recompiled with MSE loss.")

    # Fetch API data
    api_data = fetch_aggregated_data(api_url)

    # Retry fetching a wallet with topHoldings
    max_retries = 10
    for attempt in range(max_retries):
        # $sample draws one random document from the collection
        random_wallet = collection.aggregate([{"$sample": {"size": 1}}]).next()
        state = extract_state(random_wallet, api_data)
        if state is not None:
            break
        print(f"Attempt {attempt + 1}: Wallet has no topHoldings. Retrying...")
    else:
        # for/else: runs only when the loop exhausted without a break
        print("Failed to fetch a valid wallet with topHoldings after multiple attempts.")
        return

    # Normalize the state (use a pre-fitted scaler if available)
    # NOTE(review): fitting StandardScaler on a single sample standardizes
    # every feature to zero — consider reusing the training-time scaler.
    states, scaler = normalize_states([state])
    normalized_state = states[0]  # Extract the first (and only) normalized state

    # Predict actions using the model
    q_values = model.predict(normalized_state.reshape(1, -1))[0]

    # Iterate over all tokens in the wallet and output the best action
    print(f"Testing on wallet: {random_wallet['address']}")
    print(f"Q-values: {q_values}")

    for token_index, holding in enumerate(random_wallet["topHoldings"]):
        # Determine the best action for this token (3 Q-values per token)
        actions = ["Buy", "Sell", "Hold"]
        # NOTE(review): assumes 3 Q-values exist per holding; with more
        # holdings than the model covers this slice is empty and argmax raises.
        token_q_values = q_values[token_index * 3: (token_index + 1) * 3]
        best_action_index = np.argmax(token_q_values)
        best_action = actions[best_action_index]

        # Output the action for the token
        token_symbol = holding["symbol"]
        print(f"Token: {token_symbol}, Best action: {best_action}, Q-values: {token_q_values}")
116
+
117
if __name__ == "__main__":
    # Script entry point: configure endpoints, then evaluate a random wallet.
    API_URL = "https://soltrendio.com/api/stats/getTrends"  # Replace with your actual API URL
    MODEL_PATH = "rl_wallet_model.h5"  # Path to the trained model

    # Test the model on a random wallet
    test_model_on_random_wallet(API_URL, MODEL_PATH)
train.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ from tensorflow.keras import layers
4
+ from pymongo import MongoClient
5
+ import requests
6
+ import random
7
+ from collections import deque
8
+ from sklearn.preprocessing import StandardScaler
9
+ import os
10
+ from dotenv import load_dotenv
11
+
12
# Load environment variables from .env file (provides MONGODB_URI)
load_dotenv()
14
+
15
+ # MongoDB Connection Setup
16
+ # Define connection parameters and establish connection to MongoDB
17
def get_mongo_connection():
    """Return the 'wallets' collection from the walletAnalyzer database.

    Raises:
        RuntimeError: if MONGODB_URI is unset — MongoClient(None) would
            otherwise silently fall back to localhost:27017.
    """
    MONGO_URI = os.getenv("MONGODB_URI")
    if not MONGO_URI:
        raise RuntimeError("MONGODB_URI environment variable is not set")
    DB_NAME = "walletAnalyzer"
    COLLECTION_NAME = "wallets"

    # NOTE(review): tlsAllowInvalidCertificates=True disables certificate
    # verification; acceptable for development only.
    client = MongoClient(
        MONGO_URI,
        tls=True,
        tlsAllowInvalidCertificates=True
    )
    db = client[DB_NAME]
    return db[COLLECTION_NAME]
29
+
30
# Module-level handle to the 'wallets' collection, shared by the functions below.
collection = get_mongo_connection()
31
+
32
# DexScreener API URL; a token contract address is appended per request
DEX_API_URL = "https://api.dexscreener.com/latest/dex/tokens"
34
+
35
+ # Fetch Aggregated Data from API
36
+ # Retrieve portfolio and token trends from external API
37
def fetch_aggregated_data(api_url, timeout=30):
    """GET aggregated trend stats from *api_url* and return the parsed JSON.

    Args:
        api_url: Endpoint returning portfolioMetrics JSON.
        timeout: Seconds before the request aborts; the original call had no
            timeout, so a stalled server could hang training forever.
    """
    response = requests.get(api_url, timeout=timeout)
    return response.json()
40
+
41
+ # Retrieve Contract Address for Token
42
+ # Find the contract address of a Solana token by its ticker symbol
43
def get_solana_token_ca(ticker):
    """Resolve a Solana token's contract address from its ticker via DexScreener.

    Prefers an exact (case-insensitive) symbol match on either side of a
    Solana pair; otherwise falls back to the first Solana pair's base/quote
    address. Returns None when nothing is found or the request fails.
    """
    try:
        url = f"https://api.dexscreener.com/latest/dex/search?q={ticker}"
        # timeout prevents a stalled DexScreener request from hanging callers.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        data = response.json()

        wanted = ticker.upper()
        first_found_address = None
        for pair in data.get('pairs', []):
            if pair.get('chainId') != 'solana':
                continue
            base_token = pair.get('baseToken', {})
            quote_token = pair.get('quoteToken', {})

            if first_found_address is None:
                first_found_address = base_token.get('address') or quote_token.get('address')

            # Guard missing symbols: .get('symbol') can return None, and the
            # original's None.upper() raised AttributeError (not caught below,
            # since only RequestException is handled).
            if (base_token.get('symbol') or '').upper() == wanted:
                return base_token.get('address')
            if (quote_token.get('symbol') or '').upper() == wanted:
                return quote_token.get('address')

        return first_found_address
    except requests.RequestException as e:
        print(f"Error fetching token data: {e}")
        return None
68
+
69
+ # Fetch Market Data for Token
70
+ # Retrieve market cap and price data for a token contract address
71
def fetch_market_data(contract_address):
    """Fetch marketCap and priceUsd for *contract_address* from DexScreener.

    Returns:
        {'marketCap': ..., 'price': ...} — either value may be None when the
        pair omits it — or None on any failure (non-200, no pairs, exception).
    """
    try:
        # timeout added so a stalled request cannot hang the wallet-update loop.
        response = requests.get(f"{DEX_API_URL}/{contract_address}", timeout=30)
        if response.status_code == 200:
            data = response.json()
            if "pairs" in data and len(data["pairs"]) > 0:
                # Use the first pair reported by DexScreener.
                pair = data["pairs"][0]
                return {
                    "marketCap": pair.get("marketCap"),
                    "price": pair.get("priceUsd"),
                }
        else:
            print(f"Failed to fetch data for contract address {contract_address}: {response.status_code}")
    except Exception as e:
        print(f"Error fetching market data for contract address {contract_address}: {e}")
    # Empty pairs, non-200, or exception all fall through to None.
    return None
87
+
88
+ # Update Wallets in Database
89
+ # Enhance wallet data with token market information and contract addresses
90
def update_wallets():
    """Backfill contractAddress, marketCap and price on each wallet's topHoldings.

    Iterates every wallet document in Mongo; a wallet is written back only
    when at least one of its holdings was actually enriched.
    """
    wallets = collection.find()
    for wallet in wallets:
        updated = False
        if wallet.get("topHoldings") is None or not wallet["topHoldings"]:
            continue  # nothing to enrich
        for holding in wallet.get("topHoldings", []):
            symbol = holding.get("symbol")
            contract_address = holding.get("contractAddress")

            if not contract_address:
                contract_address = get_solana_token_ca(symbol)
                if contract_address:
                    holding["contractAddress"] = contract_address
                    print(f"Updated {symbol} with contract address {contract_address}.")
                    updated = True

            # Skip the market-data call when no contract address could be
            # resolved: the original issued a useless request to ".../None".
            if contract_address and ("marketCap" not in holding or "price" not in holding):
                market_data = fetch_market_data(contract_address)
                if market_data:
                    holding["marketCap"] = market_data["marketCap"]
                    holding["price"] = market_data["price"]
                    updated = True
                    print(f"Updated {symbol} in wallet {wallet['address']} with marketCap and price.")

        if updated:
            collection.update_one({"_id": wallet["_id"]}, {"$set": {"topHoldings": wallet["topHoldings"]}})
117
+
118
+ # Extract States from Wallet Data
119
+ # Convert wallet data into numerical state vectors for RL training
120
def extract_state(wallet, api_data, max_tokens=10):
    """Convert one wallet document into a fixed-length numeric state vector.

    Layout: [totalValue, holdingCount, avgMarketCap, apiAvgPortfolioValue,
    apiTotalPortfolioValue] + max_tokens balances + max_tokens prices.

    Returns:
        numpy array of length 5 + 2 * max_tokens, or None when the wallet
        has no topHoldings.
    """
    if wallet.get("topHoldings") is None or not wallet["topHoldings"]:
        return None  # Skip wallets without topHoldings

    def _to_float(value):
        # update_wallets() can store marketCap/price as None when DexScreener
        # omits them; float(None) raised TypeError in the original, so coerce
        # bad values to 0.0 instead.
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    # Extract market caps, token prices, and balances (truncated to max_tokens)
    holdings = wallet['topHoldings'][:max_tokens]
    market_caps = [_to_float(h.get('marketCap', 0)) for h in holdings]
    token_prices = [_to_float(h.get('price', 0)) for h in holdings]
    token_balances = [_to_float(h.get('balance', 0)) for h in holdings]

    # Pad with zeros to ensure fixed length
    market_caps += [0] * (max_tokens - len(market_caps))
    token_prices += [0] * (max_tokens - len(token_prices))
    token_balances += [0] * (max_tokens - len(token_balances))

    avg_market_cap = sum(market_caps) / len(market_caps) if market_caps else 0

    # Ensure totalValue is a float (Mongo extended JSON wraps doubles)
    total_value = wallet['totalValue']
    if isinstance(total_value, dict) and '$numberDouble' in total_value:
        total_value = float(total_value['$numberDouble'])

    state = [
        total_value,                 # Total portfolio value
        len(wallet['topHoldings']),  # Full (untruncated) holding count
        avg_market_cap,              # Average market cap of top holdings
        api_data['portfolioMetrics']['averagePortfolioValue'],  # Average portfolio value
        api_data['portfolioMetrics']['totalPortfolioValue'],    # Total portfolio value
    ] + token_balances + token_prices  # Include token balances and prices

    return np.array(state)
150
+
151
+
152
+
153
+ # Normalize Features
154
+ # Standardize state features for improved model training
155
def normalize_states(states):
    """Fit a StandardScaler on *states* and return the standardized array.

    The fitted scaler is discarded; callers only receive the transformed data.
    """
    return StandardScaler().fit_transform(states)
158
+
159
+ # Create DQN Model
160
+ # Define the neural network for the RL agent
161
def create_dqn(state_size, action_size):
    """Build and compile the Q-network: two 64-unit ReLU layers feeding a
    linear head with one output (Q-value) per discrete action.

    Args:
        state_size: Length of the input state vector.
        action_size: Number of discrete actions (3 per token: Buy/Sell/Hold).

    Returns:
        A compiled tf.keras.Sequential model (Adam lr=0.001, MSE loss).
    """
    model = tf.keras.Sequential([
        layers.Dense(64, activation='relu', input_shape=(state_size,)),
        layers.Dense(64, activation='relu'),
        layers.Dense(action_size, activation='linear')
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse')
    return model
169
+
170
+ # Simulate Next State
171
+ # Generate the next state based on the action taken and simulated market changes
172
def simulate_next_state(state, action, token_prices, token_balances):
    """Apply *action* to a copy of *state* and simulate one market tick.

    Action encoding: action // 3 selects the token, action % 3 the move
    (0=Buy, 1=Sell, 2=Hold). Token balances live at state[5:].

    Returns:
        (next_state, new_token_prices). The input *state* is left untouched:
        the original mutated it in place, so replay-buffer entries aliased
        state == next_state and the reward's value change was always zero.
    """
    next_state = np.array(state, dtype=float)  # defensive copy — the fix
    token_idx = action // 3   # Determine token index
    action_type = action % 3  # Determine action type (Buy, Sell, Hold)

    if action_type == 0:    # Buy
        next_state[token_idx + 5] += 1  # Increase token balance
    elif action_type == 1:  # Sell
        next_state[token_idx + 5] -= 1  # Decrease token balance

    # Simulate market changes in token prices (uniform ±5% random walk)
    token_prices = [price * np.random.uniform(0.95, 1.05) for price in token_prices]

    # Update total portfolio value based on new token balances and prices.
    # NOTE(review): this intentionally preserves the original's use of the
    # *passed-in* token_balances, which does not reflect the balance changed
    # above — confirm whether the traded balance should be included.
    total_value = sum(balance * price for balance, price in zip(token_balances, token_prices))
    next_state[0] = total_value  # Update total portfolio value

    return next_state, token_prices
189
+
190
+ # Termination Logic
191
+ # Define conditions for ending an RL episode
192
def check_termination(state, steps, max_steps):
    """An episode ends once the portfolio is wiped out or the step budget is spent."""
    portfolio_value = state[0]
    if portfolio_value <= 0 or steps >= max_steps:
        return True
    return False
196
+
197
+ # Reward Function
198
+ # Calculate rewards based on portfolio performance and risk
199
def calculate_reward(old_value, new_value, diversification_score):
    """Relative portfolio change, penalized by (1 - diversification_score).

    When old_value is not positive the change term is treated as zero, so the
    reward reduces to the pure diversification penalty.
    """
    if old_value > 0:
        value_change = (new_value - old_value) / old_value
    else:
        value_change = 0
    # Penalize lack of diversification
    return value_change - (1 - diversification_score)
203
+
204
+ # Main Training Loop
205
+ # Train the RL agent by simulating interactions with the environment
206
+ # Main Training Loop
207
def train_rl_model(api_url):
    """Train a DQN trading agent on wallet states sampled from MongoDB.

    Pipeline: enrich wallets, build and normalize state vectors, then run
    `episodes` simulated episodes with epsilon-greedy control and experience
    replay, finally saving the model to rl_wallet_model.h5.
    """
    update_wallets()  # Ensure wallets are updated before training

    data = collection.find()  # Fetch data directly from updated collection
    api_data = fetch_aggregated_data(api_url)

    # Extract valid states, skipping None
    states = [
        state for wallet in data
        if (state := extract_state(wallet, api_data)) is not None
    ]

    # Ensure states list is not empty before proceeding
    if not states:
        print("No valid states available for training.")
        return

    states = normalize_states(states)

    state_size = len(states[0])
    # NOTE(review): states[0][5:] holds balances AND prices (2 * max_tokens
    # entries), so this allocates 3 actions for twice the number of real
    # tokens — confirm whether 3 * max_tokens was intended.
    action_size = 3 * len(states[0][5:])  # Buy, Sell, Hold for each token

    model = create_dqn(state_size, action_size)
    replay_buffer = deque(maxlen=2000)
    gamma = 0.99  # discount factor

    episodes = 100
    for episode in range(episodes):
        state = states[np.random.randint(0, len(states))]  # Random start
        # NOTE: with k = len(state[5:]), state[-k:] and state[5:5+k] are BOTH
        # the entire tail state[5:], so "prices" and "balances" here are the
        # same 2*max_tokens values — verify the intended slice boundaries.
        token_prices = state[-len(state[5:]):]  # Extract token prices from state
        token_balances = state[5:5 + len(state[5:])]  # Extract token balances from state
        steps = 0
        done = False

        while not done:
            # Fixed epsilon = 0.1 (no decay schedule).
            if np.random.rand() < 0.1:  # Exploration
                action = np.random.choice(action_size)
            else:  # Exploitation
                action = np.argmax(model.predict(state.reshape(1, -1))[0])

            next_state, token_prices = simulate_next_state(state, action, token_prices, token_balances)
            diversification_score = len(set(token_balances)) / len(token_balances)  # Diversity metric
            reward = calculate_reward(state[0], next_state[0], diversification_score)
            done = check_termination(next_state, steps, max_steps=50)

            replay_buffer.append((state, action, reward, next_state, done))

            # Train on a random minibatch once the buffer has enough samples.
            if len(replay_buffer) > 32:
                minibatch = random.sample(replay_buffer, 32)
                states_mb, actions_mb, rewards_mb, next_states_mb, dones_mb = zip(*minibatch)

                targets = model.predict(np.array(states_mb))
                # NOTE(review): next-state Q-values come from the same online
                # network — this is plain DQN, not double DQN, despite the
                # file name; confirm whether a target network was intended.
                next_q_values = model.predict(np.array(next_states_mb))

                for i in range(32):
                    if dones_mb[i]:
                        targets[i][actions_mb[i]] = rewards_mb[i]
                    else:
                        targets[i][actions_mb[i]] = rewards_mb[i] + gamma * np.max(next_q_values[i])

                model.fit(np.array(states_mb), targets, epochs=1, verbose=0)

            state = next_state
            steps += 1

        print(f"Episode {episode + 1}/{episodes} completed.")

    model.save("rl_wallet_model.h5")
    print("Model training complete and saved.")
276
+
277
+
278
if __name__ == "__main__":
    # Script entry point: train the agent against the live trends endpoint.
    API_URL = "https://soltrendio.com/api/stats/getTrends"  # Replace with your actual API URL
    train_rl_model(API_URL)