Spaces:
Runtime error
Runtime error
| from flask import Flask, render_template | |
| from flask_socketio import SocketIO, emit | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import numpy as np | |
| from PIL import Image, ImageDraw | |
| import io # Changed this line - io is a built-in Python module | |
| import time | |
| import threading | |
| import random | |
app = Flask(__name__)
socketio = SocketIO(app)

# Initialize the model. Use half precision only when a CUDA device is
# available: on CPU-only hosts some PyTorch versions lack float16 matmul
# kernels (e.g. '"addmm_impl_cpu_" not implemented for \'Half\''), so fall
# back to float32 there.
MODEL_NAME = "Qwen/Qwen-1_8B-Chat"
_MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=_MODEL_DTYPE,
    device_map="auto",
    trust_remote_code=True,
)
# Game constants; a small board keeps per-frame rendering cheap.
GRID_SIZE = 12  # cells per side of the square board
CELL_SIZE = 40  # pixels per cell side in the rendered image

# Fill colours (PIL colour names) for each kind of board element.
COLORS = dict(
    background='white',
    grid='lightgray',
    snake='red',
    agent='blue',
    obstacle='gray',
)
class GameState:
    """Mutable board state for the predator (snake) / prey (agents) game.

    Every position is an ``[x, y]`` list of grid coordinates.
    """

    def __init__(self):
        # The predator starts near the middle of the board.
        self.snake = [6, 6]
        # Starting cells of the prey pieces.
        self.agents = [
            [2, 2],
            [9, 9],
            [2, 9],
        ]
        # Cells no piece may move into.
        self.obstacles = [
            [4, 4],
            [7, 7],
            [4, 7],
        ]
        self.scores = {'snake': 0, 'agents': 0}
        self.history = []  # NOTE(review): not read or written elsewhere in this view

    def get_agent_state(self, agent_idx):
        """Build the observation dict for prey piece number ``agent_idx``.

        The dict holds references to the live position lists, not copies.
        """
        others = [
            where for slot, where in enumerate(self.agents)
            if slot != agent_idx
        ]
        observation = {}
        observation['position'] = self.agents[agent_idx]
        observation['snake_pos'] = self.snake
        observation['other_agents'] = others
        observation['obstacles'] = self.obstacles
        return observation


# The single shared game instance that update_game() mutates.
game = GameState()
def get_model_decision(role, state):
    """Ask the Qwen model for the next move of the predator or a prey piece.

    Args:
        role: ``"snake"`` for the predator; any other value is treated as prey.
        state: Observation dict. The predator prompt reads ``'position'`` and
            ``'other_agents'``; the prey prompt reads ``'position'`` and
            ``'snake_pos'``.

    Returns:
        One of ``"UP"``, ``"DOWN"``, ``"LEFT"``, ``"RIGHT"`` or ``"STAY"``.
        ``"STAY"`` is returned when the model's reply names no valid move.
    """
    if role == "snake":
        prompt = f"You are a predator trying to catch prey. Your position is {state['position']}, prey positions are {state['other_agents']}. Choose one move from: UP, DOWN, LEFT, RIGHT, STAY. Just output the move word."
    else:
        prompt = f"You are prey avoiding a predator. Your position is {state['position']}, predator position is {state['snake_pos']}. Choose one move from: UP, DOWN, LEFT, RIGHT, STAY. Just output the move word."
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,
        temperature=0.7,
        do_sample=True
    )
    # For a causal LM, generate() returns prompt + completion. The prompt
    # itself lists "UP, DOWN, ...", so decoding the full sequence would
    # always yield "UP" and ignore the model's answer. Decode only the
    # newly generated tokens.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return _parse_move(response)


def _parse_move(text):
    """Return the first move word that appears in *text*, or ``"STAY"``."""
    valid_moves = ("UP", "DOWN", "LEFT", "RIGHT", "STAY")
    # Treat every non-letter as a separator so "UP." or "(LEFT)" match,
    # while longer words such as "UPWARD" do not.
    cleaned = "".join(ch if ch.isalpha() else " " for ch in text.upper())
    for word in cleaned.split():
        if word in valid_moves:
            return word
    return "STAY"
def apply_move(position, move, grid_size=None):
    """Return the cell reached by applying *move* to *position*.

    A move that would leave the board is ignored, so the piece stays put.
    Any value other than ``"UP"``/``"DOWN"``/``"LEFT"``/``"RIGHT"``
    (e.g. ``"STAY"``) is also ignored. *position* itself is never modified.

    Args:
        position: ``[x, y]`` grid coordinates; any 2-item sequence works.
        move: Move name.
        grid_size: Board side length in cells; defaults to ``GRID_SIZE``.

    Returns:
        A new ``[x, y]`` list.
    """
    if grid_size is None:
        grid_size = GRID_SIZE
    # Unpacking already copies the coordinates; no .copy() needed, which
    # also lets tuples be passed in.
    x, y = position
    last = grid_size - 1
    if move == "UP" and y > 0:
        y -= 1
    elif move == "DOWN" and y < last:
        y += 1
    elif move == "LEFT" and x > 0:
        x -= 1
    elif move == "RIGHT" and x < last:
        x += 1
    return [x, y]
def create_game_image():
    """Draw the current board held in ``game`` and encode it as a PNG.

    Layers are painted in this order:
    1. grid lines;
    2. obstacles (filled cells);
    3. agents (circles);
    4. the snake (a circle drawn last, so it sits on top);
    5. the score line in the top-left corner.

    Returns:
        io.BytesIO: the PNG bytes, with the stream position reset to 0.
    """
    board_px = GRID_SIZE * CELL_SIZE
    canvas = Image.new("RGB", (board_px, board_px), COLORS['background'])
    pen = ImageDraw.Draw(canvas)
    dot_radius = CELL_SIZE // 3

    def paint_dot(cell, colour):
        # Filled circle centred inside ``cell``.
        mid_x = (cell[0] + 0.5) * CELL_SIZE
        mid_y = (cell[1] + 0.5) * CELL_SIZE
        pen.ellipse(
            [mid_x - dot_radius, mid_y - dot_radius,
             mid_x + dot_radius, mid_y + dot_radius],
            fill=colour,
        )

    # Grid: GRID_SIZE + 1 boundary lines in each direction.
    for step in range(GRID_SIZE + 1):
        edge = step * CELL_SIZE
        pen.line([(edge, 0), (edge, board_px)], fill=COLORS['grid'])
        pen.line([(0, edge), (board_px, edge)], fill=COLORS['grid'])

    # Obstacles fill their whole cell.
    for cell in game.obstacles:
        left = cell[0] * CELL_SIZE
        top = cell[1] * CELL_SIZE
        pen.rectangle(
            [left, top, (cell[0] + 1) * CELL_SIZE, (cell[1] + 1) * CELL_SIZE],
            fill=COLORS['obstacle'],
        )

    for cell in game.agents:
        paint_dot(cell, COLORS['agent'])
    paint_dot(game.snake, COLORS['snake'])

    pen.text((10, 10), f"Snake: {game.scores['snake']} | Agents: {game.scores['agents']}", fill="black")

    png = io.BytesIO()
    canvas.save(png, format='PNG')
    png.seek(0)
    return png
def update_game():
    """Advance the game by one turn.

    The turn proceeds in four steps:
    1. The snake moves first, then each agent in list order.
    2. A move into an obstacle is rejected and that piece stays where it is.
    3. After all moves, any agent sharing the snake's cell is captured and
       the snake scores a point.
    4. Each captured agent respawns on a random unoccupied cell.
    """
    # Snake's turn. The predator prompt reads only 'position' and
    # 'other_agents'.
    snake_state = {'position': game.snake, 'other_agents': game.agents}
    snake_move = get_model_decision('snake', snake_state)
    new_pos = apply_move(game.snake, snake_move)
    if new_pos not in game.obstacles:
        game.snake = new_pos

    # Agents' turns. Replacing list items in place is safe while enumerating.
    for i, current in enumerate(game.agents):
        agent_move = get_model_decision('agent', game.get_agent_state(i))
        new_pos = apply_move(current, agent_move)
        if new_pos not in game.obstacles:
            game.agents[i] = new_pos

    # Check captures.
    for i, agent_pos in enumerate(game.agents):
        if agent_pos == game.snake:
            game.scores['snake'] += 1
            _respawn_agent(i)


def _respawn_agent(idx):
    """Move agent *idx* to a random cell free of the snake, obstacles and other agents.

    If no cell is free, the agent is left where it is. This replaces an
    unbounded retry loop that would never terminate on a full board.
    """
    blocked = {tuple(game.snake)}
    blocked.update(tuple(pos) for pos in game.obstacles)
    blocked.update(tuple(pos) for j, pos in enumerate(game.agents) if j != idx)
    free_cells = [
        [x, y]
        for x in range(GRID_SIZE)
        for y in range(GRID_SIZE)
        if (x, y) not in blocked
    ]
    if free_cells:
        game.agents[idx] = random.choice(free_cells)
def game_loop():
    """Play turns forever, broadcasting each rendered frame over Socket.IO.

    Started on a daemon thread from the ``__main__`` guard. A failure in
    one turn (e.g. a model call raising) is logged and the loop continues.
    Otherwise the exception would silently end the thread and freeze the
    game for every connected client.
    """
    while True:
        try:
            update_game()
            img_bytes = create_game_image()
            socketio.emit('game_update', {
                'image': img_bytes.getvalue().hex(),
                'scores': game.scores
            })
        except Exception:
            app.logger.exception("Game loop turn failed; continuing")
        time.sleep(1.0)  # Slower updates to reduce resource usage
@app.route('/')
def index():
    """Serve the ``index.html`` template for the site root.

    Without the route decorator, Flask never registers this view and ``/``
    responds with 404.
    """
    return render_template('index.html')
@socketio.on('connect')
def handle_connect(auth=None):
    """Log each new Socket.IO client connection.

    Without the ``socketio.on`` decorator, this handler is never registered.
    ``auth`` receives the optional client auth payload and is unused here.
    """
    print('Client connected')
if __name__ == '__main__':
    threading.Thread(target=game_loop, daemon=True).start()
    # In threading async mode, Flask-SocketIO refuses to start its Werkzeug
    # server when stdin is not a TTY (as inside a container) and raises
    # RuntimeError unless this flag is passed explicitly.
    socketio.run(app, host='0.0.0.0', port=7860, allow_unsafe_werkzeug=True)