Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from snake_env import SnakeEnv | |
| from Cod_agent import DQNAgent | |
def plot_results(scores, best_scores, avg_scores, epsilons):
    """Draw the training curves, save them as a PNG, and display them.

    The figure has two stacked panels that both use the 1-based game number
    as x value: the upper panel holds the three score curves, the lower
    panel holds the epsilon curve. The image is written to
    ``training_progress.png`` before the window is shown.

    Args:
        scores: Score of each game, in play order; its length sets the x range.
        best_scores: Series drawn as "Best score".
        avg_scores: Series drawn as "Avg last 50 games".
        epsilons: Series drawn as "Epsilon".
    """
    panel_background = "#0f1923"
    figure, (score_axis, epsilon_axis) = plt.subplots(2, 1, figsize=(10, 7))
    figure.suptitle("Snake AI — Training Results", fontsize=14, fontweight="bold")

    # One x value per recorded game, starting at game 1.
    game_numbers = list(range(1, len(scores) + 1))

    # Upper panel: curves are drawn in this order so the legend lists them
    # the same way.
    score_curves = (
        (scores, dict(color="#888888", alpha=0.5, linewidth=1,
                      label="Score per game")),
        (best_scores, dict(color="#2ed573", linewidth=2,
                           label="Best score")),
        (avg_scores, dict(color="#ffa502", linewidth=2, linestyle="--",
                          label="Avg last 50 games")),
    )
    score_axis.set_facecolor(panel_background)
    for series, line_style in score_curves:
        score_axis.plot(game_numbers, series, **line_style)
    score_axis.set_ylabel("Score")
    score_axis.legend(loc="upper left")
    score_axis.grid(True, alpha=0.2)

    # Lower panel: epsilon over the same game numbers.
    epsilon_axis.set_facecolor(panel_background)
    epsilon_axis.plot(game_numbers, epsilons, color="#7f77dd", linewidth=2,
                      label="Epsilon")
    epsilon_axis.set_ylabel("Epsilon")
    epsilon_axis.set_xlabel("Game")
    epsilon_axis.legend(loc="upper right")
    epsilon_axis.grid(True, alpha=0.2)

    # Save before show(): the file is written even if the window is closed
    # immediately afterwards.
    plt.tight_layout()
    plt.savefig("training_progress.png", dpi=150, bbox_inches="tight")
    print("Graph saved → training_progress.png")
    plt.show()
def train(num_games=1000, render=False, model_path="snake_brain.npz"):
    """Train a ``DQNAgent`` on ``SnakeEnv`` and plot the run when it finishes.

    Every game is played to completion; after each step the transition is
    stored with ``agent.remember`` and ``agent.learn`` is called. Whenever a
    game beats the best score seen so far, the agent is saved to
    ``model_path``. Per-game statistics are printed every 10 games and passed
    to ``plot_results`` at the end.

    Args:
        num_games: Number of games (episodes) to play.
        render: Forwarded to ``SnakeEnv(render=...)``. Defaults to ``False``:
            the original code passed ``True`` while its own comment said
            "no window during training", so headless is taken as the intent.
        model_path: File the agent is saved to on each new best score.
    """
    env = SnakeEnv(render=render)  # no window during training by default
    agent = DQNAgent()

    best_score = 0
    all_scores = []
    best_scores = []
    avg_scores = []
    epsilons = []

    print("Training started! Graph will show when done.\n")
    print(f"{'Game':<8} {'Score':<8} {'Best':<8} {'Avg(50)':<10} {'Epsilon'}")
    print("-" * 55)

    for game in range(1, num_games + 1):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            agent.learn()
            state = next_state

        # --- Record stats ---
        score = env.score
        all_scores.append(score)
        if score > best_score:
            # Only a strictly better game overwrites the saved model.
            best_score = score
            agent.save(model_path)
        best_scores.append(best_score)

        # Rolling mean over (at most) the 50 most recent games.
        avg = float(np.mean(all_scores[-50:]))
        avg_scores.append(avg)
        epsilons.append(agent.epsilon)

        if game % 10 == 0:
            print(f"{game:<8} {score:<8} {best_score:<8} "
                  f"{avg:<10.2f} {agent.epsilon:.4f}")

    print(f"\nTraining complete! Best score: {best_score}")
    plot_results(all_scores, best_scores, avg_scores, epsilons)
def watch(model_paths=("Cod_brain.npz", "snake_brain.npz")):
    """Load a trained model and watch it play, one game after another.

    ``train()`` in this file saves its best model as ``snake_brain.npz``,
    while this function originally looked only for ``Cod_brain.npz``, so a
    freshly trained model was never found. The candidate paths are now tried
    in order and the first one that loads is used; the original file name
    stays first so existing setups behave as before.

    Args:
        model_paths: Model files to try, in order of preference.

    Returns:
        ``None``. Returns early, after printing a hint, when none of the
        candidate files exists; otherwise loops until the process is stopped.
    """
    env = SnakeEnv(render=True)
    agent = DQNAgent()
    # Presumably disables random exploration so only learned actions are
    # played -- confirm against DQNAgent.act.
    agent.epsilon = 0.0

    for path in model_paths:
        try:
            agent.load(path)
        except FileNotFoundError:
            continue  # try the next candidate file
        break
    else:
        # No candidate could be loaded.
        print("No saved model found! Train first: python train.py")
        return

    print("Watching AI play... (close window to stop)\n")
    game = 0
    while True:
        state = env.reset()
        game += 1
        while True:
            action = agent.act(state)
            state, _, done = env.step(action)
            if done:
                print(f"Game {game} | Score: {env.score}")
                break
| # ================================================================ | |
if __name__ == "__main__":
    import sys

    # The first command-line argument selects the mode. "watch" replays a
    # saved model; anything else (or no argument) starts training.
    cli_args = sys.argv[1:]
    mode = cli_args[0] if cli_args else "train"
    if mode != "watch":
        train()
    else:
        watch()