Spaces:
Sleeping
Sleeping
Commenting out the testing part
Browse files- app.py +1 -1
- tester.py +0 -3
- trainer.py +5 -17
app.py
CHANGED
|
@@ -27,7 +27,7 @@ def main():
|
|
| 27 |
|
| 28 |
if start_button:
|
| 29 |
agent = perform_training(jammer_type, channel_switching_cost)
|
| 30 |
-
test(agent, jammer_type, channel_switching_cost)
|
| 31 |
|
| 32 |
|
| 33 |
def perform_training(jammer_type, channel_switching_cost):
|
|
|
|
| 27 |
|
| 28 |
if start_button:
|
| 29 |
agent = perform_training(jammer_type, channel_switching_cost)
|
| 30 |
+
# test(agent, jammer_type, channel_switching_cost)
|
| 31 |
|
| 32 |
|
| 33 |
def perform_training(jammer_type, channel_switching_cost):
|
tester.py
CHANGED
|
@@ -2,10 +2,7 @@
|
|
| 2 |
# -*- coding: utf-8 -*-
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
-
import matplotlib.pyplot as plt
|
| 6 |
-
import json
|
| 7 |
import streamlit as st
|
| 8 |
-
from DDQN import DoubleDeepQNetwork
|
| 9 |
from antiJamEnv import AntiJamEnv
|
| 10 |
|
| 11 |
|
|
|
|
| 2 |
# -*- coding: utf-8 -*-
|
| 3 |
|
| 4 |
import numpy as np
|
|
|
|
|
|
|
| 5 |
import streamlit as st
|
|
|
|
| 6 |
from antiJamEnv import AntiJamEnv
|
| 7 |
|
| 8 |
|
trainer.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
-
import json
|
| 7 |
import streamlit as st
|
| 8 |
from DDQN import DoubleDeepQNetwork
|
| 9 |
from antiJamEnv import AntiJamEnv
|
|
@@ -108,30 +107,18 @@ def train(jammer_type, channel_switching_cost):
|
|
| 108 |
st.subheader("Graph Explanation")
|
| 109 |
st.write(insights)
|
| 110 |
|
| 111 |
-
# Save the figure
|
| 112 |
-
# plot_name = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.png'
|
| 113 |
-
# plt.savefig(plot_name, bbox_inches='tight')
|
| 114 |
plt.close(fig) # Close the figure to release resources
|
| 115 |
|
| 116 |
-
# Save Results
|
| 117 |
-
# Rewards
|
| 118 |
-
# fileName = f'./data/train_rewards_{jammer_type}_csc_{channel_switching_cost}.json'
|
| 119 |
-
# with open(fileName, 'w') as f:
|
| 120 |
-
# json.dump(rewards, f)
|
| 121 |
-
#
|
| 122 |
-
# # Save the agent as a SavedAgent.
|
| 123 |
-
# agentName = f'./data/DDQNAgent_{jammer_type}_csc_{channel_switching_cost}'
|
| 124 |
-
# DDQN_agent.save_model(agentName)
|
| 125 |
return DDQN_agent
|
| 126 |
|
| 127 |
|
| 128 |
def generate_insights_langchain(rewards, rolling_average, epsilons, solved_threshold):
|
| 129 |
data_description = (
|
| 130 |
f"The graph represents training rewards over episodes. "
|
| 131 |
-
f"The actual rewards range from {min(rewards)} to {max(rewards)} with an average of {np.mean(rewards):.2f}. "
|
| 132 |
-
f"The rolling average values range from {min(rolling_average)} to {max(rolling_average)} with an average of {np.mean(rolling_average):.2f}. "
|
| 133 |
-
f"The epsilon values range from {min(epsilons)} to {max(epsilons)} with an average exploration rate of {np.mean(epsilons):.2f}. "
|
| 134 |
-
f"The solved threshold is set at {solved_threshold}."
|
| 135 |
)
|
| 136 |
|
| 137 |
result = llm_chain.predict(data=data_description)
|
|
@@ -139,3 +126,4 @@ def generate_insights_langchain(rewards, rolling_average, epsilons, solved_thres
|
|
| 139 |
|
| 140 |
|
| 141 |
|
|
|
|
|
|
| 3 |
|
| 4 |
import numpy as np
|
| 5 |
import matplotlib.pyplot as plt
|
|
|
|
| 6 |
import streamlit as st
|
| 7 |
from DDQN import DoubleDeepQNetwork
|
| 8 |
from antiJamEnv import AntiJamEnv
|
|
|
|
| 107 |
st.subheader("Graph Explanation")
|
| 108 |
st.write(insights)
|
| 109 |
|
|
|
|
|
|
|
|
|
|
| 110 |
plt.close(fig) # Close the figure to release resources
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
return DDQN_agent
|
| 113 |
|
| 114 |
|
| 115 |
def generate_insights_langchain(rewards, rolling_average, epsilons, solved_threshold):
|
| 116 |
data_description = (
|
| 117 |
f"The graph represents training rewards over episodes. "
|
| 118 |
+
f"The actual rewards range from {min(rewards):.2f} to {max(rewards):.2f} with an average of {np.mean(rewards):.2f}. "
|
| 119 |
+
f"The rolling average values range from {min(rolling_average):.2f} to {max(rolling_average):.2f} with an average of {np.mean(rolling_average):.2f}. "
|
| 120 |
+
f"The epsilon values range from {min(epsilons):.2f} to {max(epsilons):.2f} with an average exploration rate of {np.mean(epsilons):.2f}. "
|
| 121 |
+
f"The solved threshold is set at {solved_threshold:.2f}."
|
| 122 |
)
|
| 123 |
|
| 124 |
result = llm_chain.predict(data=data_description)
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
|
| 129 |
+
|