Ivan000 committed on
Commit
3c22597
·
verified ·
1 Parent(s): 62ea9a7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -0
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # =============
3
+ # This is a complete app.py file for an Arkanoid game that a neural network will play and learn using reinforcement learning.
4
+ # The game is built using pygame, and the neural network is trained using stable-baselines3. Gradio is used for the interface.
5
+
6
+ import os
7
+ import numpy as np
8
+ import pygame
9
+ import random
10
+ from stable_baselines3 import DQN
11
+ from stable_baselines3.common.env_util import make_atari_env
12
+ from stable_baselines3.common.vec_env import VecFrameStack
13
+ from stable_baselines3.common.evaluation import evaluate_policy
14
+ import gradio as gr
15
+
16
# Constants
# Window and sprite geometry, in pixels.
SCREEN_WIDTH = 640
SCREEN_HEIGHT = 480
PADDLE_WIDTH = 100
PADDLE_HEIGHT = 10
BALL_RADIUS = 10
BRICK_WIDTH = 60
BRICK_HEIGHT = 20
BRICK_ROWS = 5
BRICK_COLS = 10
# Frame-rate cap passed to the pygame clock when rendering.
FPS = 60

# Colors
# RGB triples used when drawing.
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)

# Initialize Pygame
pygame.init()
try:
    screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
except pygame.error:
    # No usable video device (typical on a headless server). Fall back to
    # SDL's "dummy" video driver so an off-screen surface is still created
    # and frames can be rendered and captured.
    os.environ["SDL_VIDEODRIVER"] = "dummy"
    pygame.display.quit()
    pygame.display.init()
    screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("Arkanoid")
37
+
38
+ # Game classes
39
class Paddle:
    """Horizontal bar the player (or agent) moves along the bottom of the screen."""

    def __init__(self):
        # Centered horizontally, with a 10 px gap below the paddle.
        left = SCREEN_WIDTH // 2 - PADDLE_WIDTH // 2
        top = SCREEN_HEIGHT - PADDLE_HEIGHT - 10
        self.rect = pygame.Rect(left, top, PADDLE_WIDTH, PADDLE_HEIGHT)

    def move(self, direction):
        """Shift the paddle 10 px left (``-1``) or right (``1``).

        Any other ``direction`` leaves the position unchanged. The paddle is
        always clamped to the screen rectangle afterwards.
        """
        if direction == -1:
            self.rect.move_ip(-10, 0)
        elif direction == 1:
            self.rect.move_ip(10, 0)
        screen_bounds = pygame.Rect(0, 0, SCREEN_WIDTH, SCREEN_HEIGHT)
        self.rect.clamp_ip(screen_bounds)
49
+
50
class Ball:
    """The ball: a square bounding rect plus an ``[x, y]`` velocity in px/step."""

    def __init__(self):
        # Construction and reset share the same starting state.
        self.reset()

    def move(self):
        """Advance one step and bounce off the left, right and top walls.

        The bottom edge is deliberately not handled here; the environment
        treats reaching it as the end of the episode.
        """
        self.rect.x += self.velocity[0]
        self.rect.y += self.velocity[1]

        # Force the velocity to point away from the wall instead of toggling
        # its sign: a plain sign flip would send a ball that is still
        # overlapping the wall on the next step straight back into it.
        if self.rect.left <= 0:
            self.velocity[0] = abs(self.velocity[0])
        elif self.rect.right >= SCREEN_WIDTH:
            self.velocity[0] = -abs(self.velocity[0])
        if self.rect.top <= 0:
            self.velocity[1] = abs(self.velocity[1])

    def reset(self):
        """Put the ball at the screen center, moving upward with a random horizontal direction."""
        self.rect = pygame.Rect(
            SCREEN_WIDTH // 2 - BALL_RADIUS,
            SCREEN_HEIGHT // 2 - BALL_RADIUS,
            BALL_RADIUS * 2,
            BALL_RADIUS * 2,
        )
        self.velocity = [random.choice([-5, 5]), -5]
67
+
68
class Brick:
    """One destructible brick, described only by its rectangle."""

    def __init__(self, x, y):
        # (x, y) is the top-left corner; every brick has the same size.
        top_left = (x, y)
        size = (BRICK_WIDTH, BRICK_HEIGHT)
        self.rect = pygame.Rect(top_left, size)
71
+
72
class ArkanoidEnv:
    """Arkanoid game state exposed through ``reset`` / ``step`` / ``render``.

    NOTE(review): this class does not subclass a Gym/Gymnasium ``Env`` and
    defines no ``observation_space`` / ``action_space``. stable-baselines3
    algorithms normally expect those -- confirm before training on it.
    NOTE(review): ``step`` forwards ``action`` unchanged to ``Paddle.move``,
    which reacts only to ``-1`` and ``1``; verify how the agent's action
    values are meant to map onto those.
    """

    def __init__(self):
        self.clock = pygame.time.Clock()
        self._start_episode()

    @staticmethod
    def _build_bricks():
        """Return a fresh brick grid laid out on BRICK_WIDTH x BRICK_HEIGHT steps."""
        rows = range(BRICK_HEIGHT, BRICK_HEIGHT * (BRICK_ROWS + 1), BRICK_HEIGHT)
        cols = range(BRICK_WIDTH, SCREEN_WIDTH - BRICK_WIDTH, BRICK_WIDTH)
        return [Brick(x, y) for y in rows for x in cols]

    def _start_episode(self):
        """Create a new paddle, ball and brick grid and clear score/done."""
        self.paddle = Paddle()
        self.ball = Ball()
        self.bricks = self._build_bricks()
        # Immutable snapshot of the full grid so that observations keep one
        # slot per brick even after bricks are removed from ``self.bricks``.
        self._layout = tuple(self.bricks)
        self.done = False
        self.score = 0

    def reset(self):
        """Start a new episode and return its first observation."""
        self._start_episode()
        return self._get_state()

    def step(self, action):
        """Advance the game by one tick.

        Args:
            action: passed straight to ``Paddle.move``.

        Returns:
            ``(state, reward, done, info)``. ``reward`` is the number of
            bricks destroyed during this step, minus 1 if the ball reached
            the bottom edge. ``info`` is always an empty dict.
        """
        self.paddle.move(action)
        self.ball.move()

        reward = 0

        # Bounce off the paddle only while the ball is travelling downward,
        # so a ball that still overlaps the paddle on the following step is
        # not flipped back down into it.
        if self.ball.velocity[1] > 0 and self.ball.rect.colliderect(self.paddle.rect):
            self.ball.velocity[1] = -self.ball.velocity[1]

        hit_bricks = [
            brick for brick in self.bricks
            if self.ball.rect.colliderect(brick.rect)
        ]
        if hit_bricks:
            for brick in hit_bricks:
                self.bricks.remove(brick)
            # Reverse once per step: reversing once per brick would cancel
            # out whenever an even number of bricks is hit simultaneously.
            self.ball.velocity[1] = -self.ball.velocity[1]
            self.score += len(hit_bricks)
            reward += len(hit_bricks)

        if self.ball.rect.bottom >= SCREEN_HEIGHT:
            # Ball lost: end the episode with a penalty.
            self.done = True
            reward -= 1

        if not self.bricks:
            # Every brick cleared.
            self.done = True

        return self._get_state(), reward, self.done, {}

    def _get_state(self):
        """Return the observation as a fixed-length float32 vector.

        Layout: paddle x, ball x, ball y, ball vx, ball vy, followed by an
        ``(x, y)`` pair for every brick of the initial grid. Destroyed bricks
        keep their slot and are reported as ``(-1, -1)`` so the vector length
        does not change during an episode.
        """
        state = [
            self.paddle.rect.x,
            self.ball.rect.x,
            self.ball.rect.y,
            self.ball.velocity[0],
            self.ball.velocity[1],
        ]
        # Brick defines no __eq__, so identity is the right notion of "same brick".
        alive = {id(brick) for brick in self.bricks}
        for brick in self._layout:
            if id(brick) in alive:
                state.extend([brick.rect.x, brick.rect.y])
            else:
                state.extend([-1, -1])
        return np.array(state, dtype=np.float32)

    def render(self):
        """Draw the current frame on the module-level ``screen`` and cap the rate at FPS."""
        screen.fill(BLACK)
        pygame.draw.rect(screen, WHITE, self.paddle.rect)
        pygame.draw.ellipse(screen, WHITE, self.ball.rect)
        for brick in self.bricks:
            pygame.draw.rect(screen, RED, brick.rect)
        pygame.display.flip()
        self.clock.tick(FPS)
131
+
132
+ # Training function
133
def train_model(total_timesteps=10000, save_path="arkanoid_model"):
    """Train a DQN agent on a fresh ``ArkanoidEnv`` and save it to disk.

    Args:
        total_timesteps: training budget passed to ``model.learn``.
        save_path: path passed to ``model.save``.

    Returns:
        The trained ``DQN`` model.
    """
    env = ArkanoidEnv()
    model = DQN('MlpPolicy', env, verbose=1)
    model.learn(total_timesteps=total_timesteps)
    model.save(save_path)
    return model
139
+
140
+ # Evaluation function
141
def evaluate_model(model):
    """Evaluate ``model`` for 10 episodes on a fresh ``ArkanoidEnv``.

    Returns the first element of the pair returned by ``evaluate_policy``
    (the mean episode reward); the second element is discarded.
    """
    eval_env = ArkanoidEnv()
    stats = evaluate_policy(model, eval_env, n_eval_episodes=10, render=False)
    return stats[0]
145
+
146
+ # Gradio interface
147
def play_game(max_steps=5000):
    """Play one episode with the saved model and return the final frame.

    The Gradio interface declares a single ``"image"`` output, so exactly one
    image is returned: the last rendered frame, read directly from the
    display surface rather than written to disk on every step.

    Args:
        max_steps: upper bound on the number of environment steps, so an
            agent that never loses the ball cannot keep the request running
            forever.

    Returns:
        RGB pixel array of shape ``(SCREEN_HEIGHT, SCREEN_WIDTH, 3)``.
    """
    env = ArkanoidEnv()
    model = DQN.load("arkanoid_model")
    obs = env.reset()
    for _ in range(max_steps):
        action, _states = model.predict(obs, deterministic=True)
        obs, _reward, done, _info = env.step(action)
        env.render()
        if done:
            break
    # surfarray is indexed (x, y, channel); swap the first two axes to get
    # the conventional (row, column, channel) image layout.
    frame = pygame.surfarray.array3d(screen)
    return np.transpose(frame, (1, 0, 2))
160
+
161
+ # Main function
162
def main():
    """Train the model if no saved copy exists, evaluate it, then launch the Gradio UI."""
    model_exists = os.path.exists("arkanoid_model.zip")
    if model_exists:
        print("Model already trained.")
    else:
        print("Training model...")
        train_model()

    print("Evaluating model...")
    loaded_model = DQN.load("arkanoid_model")
    average_reward = evaluate_model(loaded_model)
    print(f"Mean reward: {average_reward}")

    # Gradio interface
    demo = gr.Interface(fn=play_game, inputs=None, outputs="image", live=True)
    demo.launch()


if __name__ == "__main__":
    main()
185
+
186
+ # Dependencies
187
+ # =============
188
+ # The following dependencies are required to run this app:
189
+ # - pygame
190
+ # - stable-baselines3
191
+ # - torch
192
+ # - gradio
193
+ #
194
+ # You can install these dependencies using pip:
195
+ # pip install pygame stable-baselines3 torch gradio