Spaces:
Runtime error
Runtime error
File size: 5,989 Bytes
78ea89a 141636f 78ea89a 141636f 78ea89a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
from flask import Flask, render_template
from flask_socketio import SocketIO, emit
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from PIL import Image, ImageDraw
import io # Changed this line - io is a built-in Python module
import time
import threading
import random
# Web application object and the Socket.IO layer wrapped around it.
app = Flask(__name__)
socketio = SocketIO(app)

# Load the tokenizer and the chat model once, at import time.
# The model weights are requested in half precision (float16).
MODEL_NAME = "Qwen/Qwen-1_8B-Chat"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
)
# Game constants
GRID_SIZE = 12  # Cells per side of the square board (kept small for performance)
CELL_SIZE = 40  # Side length of one cell in the rendered image

# Fill colour used for each kind of thing drawn on the board.
COLORS = dict(
    background='white',
    grid='lightgray',
    snake='red',
    agent='blue',
    obstacle='gray',
)
class GameState:
    """Mutable state of one game: positions, scores and a history list.

    Every position is a two-element ``[x, y]`` list of grid coordinates.
    """

    def __init__(self):
        """Set up the fixed starting layout with both scores at zero."""
        self.history = []
        self.scores = {'snake': 0, 'agents': 0}
        self.obstacles = [[4, 4], [7, 7], [4, 7]]
        self.agents = [[2, 2], [9, 9], [2, 9]]
        self.snake = [6, 6]  # Center

    def get_agent_state(self, agent_idx):
        """Return the view of the board handed to the agent at *agent_idx*.

        The result holds the agent's own position, the snake's position,
        the positions of all other agents and the obstacle list. The
        position lists are the live objects stored on this instance, not
        copies.
        """
        others = [
            other_pos
            for other_idx, other_pos in enumerate(self.agents)
            if other_idx != agent_idx
        ]
        view = {}
        view['position'] = self.agents[agent_idx]
        view['snake_pos'] = self.snake
        view['other_agents'] = others
        view['obstacles'] = self.obstacles
        return view


# Single shared game instance used by the module-level functions.
game = GameState()
def _first_move_word(text):
    """Return the first whole word of *text* that names a move, else "STAY"."""
    valid_moves = ("UP", "DOWN", "LEFT", "RIGHT", "STAY")
    # Turn every non-letter into a space so "LEFT." or "(up)" still split
    # into clean words, and so "UP" cannot match inside e.g. "SUPPORT".
    letters_only = "".join(ch if ch.isalpha() else " " for ch in text.upper())
    for word in letters_only.split():
        if word in valid_moves:
            return word
    return "STAY"


def get_model_decision(role, state):
    """Get next move from Qwen model.

    Args:
        role: ``"snake"`` selects the predator prompt; any other value
            selects the prey prompt.
        state: Mapping describing the board. The predator prompt reads
            ``state['position']`` and ``state['other_agents']``; the prey
            prompt reads ``state['position']`` and ``state['snake_pos']``.

    Returns:
        One of ``"UP"``, ``"DOWN"``, ``"LEFT"``, ``"RIGHT"``, ``"STAY"``.
        ``"STAY"`` is returned when the model's reply names no move.
    """
    if role == "snake":
        prompt = (
            "You are a predator trying to catch prey. "
            f"Your position is {state['position']}, "
            f"prey positions are {state['other_agents']}. "
            "Choose one move from: UP, DOWN, LEFT, RIGHT, STAY. "
            "Just output the move word."
        )
    else:
        prompt = (
            "You are prey avoiding a predator. "
            f"Your position is {state['position']}, "
            f"predator position is {state['snake_pos']}. "
            "Choose one move from: UP, DOWN, LEFT, RIGHT, STAY. "
            "Just output the move word."
        )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,
        temperature=0.7,
        do_sample=True
    )

    # generate() returns the prompt tokens followed by the new tokens.
    # The prompt itself lists every move word, so decoding the whole
    # sequence would always find "UP" first; keep only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    return _first_move_word(response)
def apply_move(position, move):
    """Apply move while respecting grid boundaries.

    Args:
        position: Two-element sequence ``(x, y)`` of grid coordinates. Any
            unpackable sequence (list, tuple, ...) is accepted; it is never
            modified.
        move: ``"UP"`` (y - 1), ``"DOWN"`` (y + 1), ``"LEFT"`` (x - 1) or
            ``"RIGHT"`` (x + 1). Any other value, including ``"STAY"``,
            leaves the position unchanged.

    Returns:
        A new ``[x, y]`` list. A move that would leave the range
        ``0 .. GRID_SIZE - 1`` on either axis is ignored.
    """
    # Plain unpacking already yields independent ints, so no copy of the
    # input is needed and tuples work as well as lists.
    x, y = position
    if move == "UP" and y > 0:
        y -= 1
    elif move == "DOWN" and y < GRID_SIZE - 1:
        y += 1
    elif move == "LEFT" and x > 0:
        x -= 1
    elif move == "RIGHT" and x < GRID_SIZE - 1:
        x += 1
    return [x, y]
def create_game_image():
    """Render the current board as a PNG and return it as a BytesIO.

    Layers are painted in this order, later ones on top: grid lines,
    obstacles, agents, snake, score text. The returned buffer is rewound
    to position 0.
    """
    board_side = GRID_SIZE * CELL_SIZE
    img = Image.new("RGB", (board_side, board_side), COLORS['background'])
    draw = ImageDraw.Draw(img)

    def paint_disc(cell, fill):
        # Circle centred in the cell, radius one third of the cell size.
        mid_x = (cell[0] + 0.5) * CELL_SIZE
        mid_y = (cell[1] + 0.5) * CELL_SIZE
        r = CELL_SIZE // 3
        draw.ellipse([mid_x - r, mid_y - r, mid_x + r, mid_y + r], fill=fill)

    # Grid: one vertical and one horizontal line per cell boundary.
    for line_idx in range(GRID_SIZE + 1):
        offset = line_idx * CELL_SIZE
        draw.line([(offset, 0), (offset, board_side)], fill=COLORS['grid'])
        draw.line([(0, offset), (board_side, offset)], fill=COLORS['grid'])

    # Obstacles: filled squares covering their whole cell.
    for cell in game.obstacles:
        left = cell[0] * CELL_SIZE
        top = cell[1] * CELL_SIZE
        right = (cell[0] + 1) * CELL_SIZE
        bottom = (cell[1] + 1) * CELL_SIZE
        draw.rectangle([left, top, right, bottom], fill=COLORS['obstacle'])

    # Agents first, then the snake, so the snake is drawn over an agent
    # sharing its cell.
    for cell in game.agents:
        paint_disc(cell, COLORS['agent'])
    paint_disc(game.snake, COLORS['snake'])

    # Score line in the top-left corner.
    draw.text((10, 10), f"Snake: {game.scores['snake']} | Agents: {game.scores['agents']}", fill="black")

    # Serialise to PNG in memory.
    png_buffer = io.BytesIO()
    img.save(png_buffer, format='PNG')
    png_buffer.seek(0)
    return png_buffer
def _pick_respawn_cell():
    """Return a random unoccupied ``[x, y]`` cell, or None if none is free."""
    # A cell is occupied if the snake, an obstacle or any agent is on it.
    occupied = [game.snake, *game.obstacles, *game.agents]
    free_cells = [
        [x, y]
        for x in range(GRID_SIZE)
        for y in range(GRID_SIZE)
        if [x, y] not in occupied
    ]
    return random.choice(free_cells) if free_cells else None


def update_game():
    """Update game state for one turn.

    The snake moves first, then each agent in list order; a move onto an
    obstacle cell is discarded. Afterwards every agent standing on the
    snake's cell counts as captured: the snake's score goes up by one and
    the agent is respawned on a random free cell. Mutates the module-level
    ``game``; returns nothing.
    """
    # Snake's turn
    snake_state = {'position': game.snake, 'other_agents': game.agents}
    snake_move = get_model_decision('snake', snake_state)
    new_pos = apply_move(game.snake, snake_move)
    if new_pos not in game.obstacles:
        game.snake = new_pos

    # Agents' turns. Replacing an element by index while enumerating is
    # safe because the list length never changes.
    for agent_idx, agent_pos in enumerate(game.agents):
        agent_state = game.get_agent_state(agent_idx)
        agent_move = get_model_decision('agent', agent_state)
        new_pos = apply_move(agent_pos, agent_move)
        if new_pos not in game.obstacles:
            game.agents[agent_idx] = new_pos

    # Check captures
    for agent_idx, agent_pos in enumerate(game.agents):
        if agent_pos != game.snake:
            continue
        game.scores['snake'] += 1
        # Respawn on a cell that is free of the snake, obstacles and the
        # other agents. If the board has no free cell the agent stays put
        # rather than looping forever looking for one.
        respawn_pos = _pick_respawn_cell()
        if respawn_pos is not None:
            game.agents[agent_idx] = respawn_pos
def game_loop():
    """Main game loop.

    Runs forever: advances the game one turn, renders the board and emits
    a ``game_update`` Socket.IO event carrying the PNG as a hex string and
    the current scores, then sleeps for one second. Never returns.
    """
    while True:
        try:
            update_game()
            img_bytes = create_game_image()
            socketio.emit('game_update', {
                'image': img_bytes.getvalue().hex(),
                'scores': game.scores
            })
        except Exception:
            # This is the top of a background worker: an uncaught error
            # would end the loop and freeze the game for every client.
            # Log the traceback and try again on the next tick.
            app.logger.exception("Game loop iteration failed")
        time.sleep(1.0)  # Slower updates to reduce resource usage
@app.route('/')
def index():
    """Handle requests to the root URL by rendering ``index.html``."""
    page = render_template('index.html')
    return page
@socketio.on('connect')
def handle_connect():
    """Print a notice each time the Socket.IO ``connect`` event fires."""
    notice = 'Client connected'
    print(notice)
if __name__ == '__main__':
    # Run the game loop on a daemon thread so it ends with the process,
    # then serve on all interfaces, port 7860.
    # NOTE(review): Flask-SocketIO also offers start_background_task();
    # whether a plain thread suffices depends on the async mode in use.
    threading.Thread(target=game_loop, daemon=True).start()
    socketio.run(app, host='0.0.0.0', port=7860)