SmokeyBandit commited on
Commit
78ea89a
·
verified ·
1 Parent(s): 3263cc6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -0
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template
2
+ from flask_socketio import SocketIO, emit
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ import numpy as np
6
+ from PIL import Image, ImageDraw, io
7
+ import time
8
+ import threading
9
+ import random
10
+
11
+ app = Flask(__name__)
12
+ socketio = SocketIO(app)
13
+
14
+ # Initialize model with lower precision
15
+ MODEL_NAME = "Qwen/Qwen-1_5B-Chat"
16
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+ MODEL_NAME,
19
+ torch_dtype=torch.float16,
20
+ device_map="auto",
21
+ trust_remote_code=True
22
+ )
23
+
24
+ # Game Constants
25
+ GRID_SIZE = 12 # Smaller grid for performance
26
+ CELL_SIZE = 40
27
+ COLORS = {
28
+ 'background': 'white',
29
+ 'grid': 'lightgray',
30
+ 'snake': 'red',
31
+ 'agent': 'blue',
32
+ 'obstacle': 'gray'
33
+ }
34
+
35
+ class GameState:
36
+ def __init__(self):
37
+ self.snake = [6, 6] # Center
38
+ self.agents = [[2, 2], [9, 9], [2, 9]]
39
+ self.obstacles = [[4, 4], [7, 7], [4, 7]]
40
+ self.scores = {'snake': 0, 'agents': 0}
41
+ self.history = []
42
+
43
+ def get_agent_state(self, agent_idx):
44
+ return {
45
+ 'position': self.agents[agent_idx],
46
+ 'snake_pos': self.snake,
47
+ 'other_agents': [pos for i, pos in enumerate(self.agents) if i != agent_idx],
48
+ 'obstacles': self.obstacles
49
+ }
50
+
51
+ game = GameState()
52
+
53
+ def get_model_decision(role, state):
54
+ """Get next move from Qwen model."""
55
+ if role == "snake":
56
+ prompt = f"You are a predator trying to catch prey. Your position is {state['position']}, prey positions are {state['other_agents']}. Choose one move from: UP, DOWN, LEFT, RIGHT, STAY. Just output the move word."
57
+ else:
58
+ prompt = f"You are prey avoiding a predator. Your position is {state['position']}, predator position is {state['snake_pos']}. Choose one move from: UP, DOWN, LEFT, RIGHT, STAY. Just output the move word."
59
+
60
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
61
+ outputs = model.generate(
62
+ **inputs,
63
+ max_new_tokens=10,
64
+ temperature=0.7,
65
+ do_sample=True
66
+ )
67
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
+
69
+ # Extract move from response
70
+ moves = ["UP", "DOWN", "LEFT", "RIGHT", "STAY"]
71
+ for move in moves:
72
+ if move in response.upper():
73
+ return move
74
+ return "STAY"
75
+
76
+ def apply_move(position, move):
77
+ """Apply move while respecting grid boundaries."""
78
+ x, y = position.copy()
79
+ if move == "UP" and y > 0:
80
+ y -= 1
81
+ elif move == "DOWN" and y < GRID_SIZE - 1:
82
+ y += 1
83
+ elif move == "LEFT" and x > 0:
84
+ x -= 1
85
+ elif move == "RIGHT" and x < GRID_SIZE - 1:
86
+ x += 1
87
+ return [x, y]
88
+
89
+ def create_game_image():
90
+ """Create game visualization."""
91
+ img = Image.new("RGB", (GRID_SIZE * CELL_SIZE, GRID_SIZE * CELL_SIZE), COLORS['background'])
92
+ draw = ImageDraw.Draw(img)
93
+
94
+ # Draw grid
95
+ for i in range(GRID_SIZE + 1):
96
+ draw.line([(i * CELL_SIZE, 0), (i * CELL_SIZE, GRID_SIZE * CELL_SIZE)], fill=COLORS['grid'])
97
+ draw.line([(0, i * CELL_SIZE), (GRID_SIZE * CELL_SIZE, i * CELL_SIZE)], fill=COLORS['grid'])
98
+
99
+ # Draw obstacles
100
+ for pos in game.obstacles:
101
+ draw.rectangle([
102
+ pos[0] * CELL_SIZE, pos[1] * CELL_SIZE,
103
+ (pos[0] + 1) * CELL_SIZE, (pos[1] + 1) * CELL_SIZE
104
+ ], fill=COLORS['obstacle'])
105
+
106
+ # Draw agents
107
+ for pos in game.agents:
108
+ center = ((pos[0] + 0.5) * CELL_SIZE, (pos[1] + 0.5) * CELL_SIZE)
109
+ radius = CELL_SIZE // 3
110
+ draw.ellipse([
111
+ center[0] - radius, center[1] - radius,
112
+ center[0] + radius, center[1] + radius
113
+ ], fill=COLORS['agent'])
114
+
115
+ # Draw snake
116
+ center = ((game.snake[0] + 0.5) * CELL_SIZE, (game.snake[1] + 0.5) * CELL_SIZE)
117
+ radius = CELL_SIZE // 3
118
+ draw.ellipse([
119
+ center[0] - radius, center[1] - radius,
120
+ center[0] + radius, center[1] + radius
121
+ ], fill=COLORS['snake'])
122
+
123
+ # Add scores
124
+ draw.text((10, 10), f"Snake: {game.scores['snake']} | Agents: {game.scores['agents']}", fill="black")
125
+
126
+ # Convert to bytes
127
+ img_byte_arr = io.BytesIO()
128
+ img.save(img_byte_arr, format='PNG')
129
+ img_byte_arr.seek(0)
130
+ return img_byte_arr
131
+
132
+ def update_game():
133
+ """Update game state for one turn."""
134
+ # Snake's turn
135
+ snake_state = {'position': game.snake, 'other_agents': game.agents}
136
+ snake_move = get_model_decision('snake', snake_state)
137
+ new_pos = apply_move(game.snake, snake_move)
138
+ if new_pos not in game.obstacles:
139
+ game.snake = new_pos
140
+
141
+ # Agents' turns
142
+ for i in range(len(game.agents)):
143
+ agent_state = game.get_agent_state(i)
144
+ agent_move = get_model_decision('agent', agent_state)
145
+ new_pos = apply_move(game.agents[i], agent_move)
146
+ if new_pos not in game.obstacles:
147
+ game.agents[i] = new_pos
148
+
149
+ # Check captures
150
+ for i, agent_pos in enumerate(game.agents):
151
+ if agent_pos == game.snake:
152
+ game.scores['snake'] += 1
153
+ # Respawn agent
154
+ while True:
155
+ new_pos = [random.randint(0, GRID_SIZE - 1), random.randint(0, GRID_SIZE - 1)]
156
+ if new_pos not in game.obstacles and new_pos != game.snake:
157
+ game.agents[i] = new_pos
158
+ break
159
+
160
+ def game_loop():
161
+ """Main game loop."""
162
+ while True:
163
+ update_game()
164
+ img_bytes = create_game_image()
165
+ socketio.emit('game_update', {
166
+ 'image': img_bytes.getvalue().hex(),
167
+ 'scores': game.scores
168
+ })
169
+ time.sleep(1.0) # Slower updates to reduce resource usage
170
+
171
+ @app.route('/')
172
+ def index():
173
+ return render_template('index.html')
174
+
175
+ @socketio.on('connect')
176
+ def handle_connect():
177
+ print('Client connected')
178
+
179
+ if __name__ == '__main__':
180
+ threading.Thread(target=game_loop, daemon=True).start()
181
+ socketio.run(app, host='0.0.0.0', port=7860)