Spaces:
Build error
Build error
| """ | |
| SPIRAL: Interactive Reasoning Game Simulator | |
| Main Gradio application for the SPIRAL demo on Hugging Face Spaces. | |
| """ | |
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import os | |
| import sys | |
| import traceback | |
| import yaml | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| import torch | |
| import spaces | |
| # Add src to path for imports | |
| current_dir = os.path.dirname(os.path.abspath(__file__)) | |
| src_path = os.path.join(current_dir, 'src') | |
| sys.path.insert(0, src_path) | |
| print(f"๐ Current directory: {current_dir}") | |
| print(f"๐ Source path: {src_path}") | |
| print(f"๐ Python path: {sys.path[:3]}") # Show first 3 entries | |
| # Check if src directory exists | |
| if os.path.exists(src_path): | |
| print(f"โ Source directory exists: {src_path}") | |
| games_path = os.path.join(src_path, 'games') | |
| if os.path.exists(games_path): | |
| print(f"โ Games directory exists: {games_path}") | |
| print(f"๐ Games directory contents: {os.listdir(games_path)}") | |
| else: | |
| print(f"โ Games directory not found: {games_path}") | |
| else: | |
| print(f"โ Source directory not found: {src_path}") | |
| # Try multiple import approaches | |
| GAMES_AVAILABLE = False | |
| tictactoe_env = None | |
| kuhn_env = None | |
| try: | |
| # Method 1: Direct import from games module | |
| print("๐ Attempting Method 1: Direct import from games") | |
| from games import TicTacToeEnv, KuhnPokerEnv | |
| print("โ Method 1 successful: Imported from games module") | |
| GAMES_AVAILABLE = True | |
| except ImportError as e: | |
| print(f"โ Method 1 failed: {e}") | |
| try: | |
| # Method 2: Import from src.games | |
| print("๐ Attempting Method 2: Import from src.games") | |
| from src.games import TicTacToeEnv, KuhnPokerEnv | |
| print("โ Method 2 successful: Imported from src.games") | |
| GAMES_AVAILABLE = True | |
| except ImportError as e: | |
| print(f"โ Method 2 failed: {e}") | |
| try: | |
| # Method 3: Direct file imports | |
| print("๐ Attempting Method 3: Direct file imports") | |
| sys.path.insert(0, games_path) | |
| from tictactoe import TicTacToeEnv | |
| from kuhn_poker import KuhnPokerEnv | |
| print("โ Method 3 successful: Direct file imports") | |
| GAMES_AVAILABLE = True | |
| except Exception as e: | |
| print(f"โ Method 3 failed: {e}") | |
| print("๐ Full traceback:", traceback.format_exc()) | |
| if GAMES_AVAILABLE: | |
| print("๐ฎ Game modules successfully imported!") | |
| try: | |
| # Test instantiation | |
| tictactoe_env = TicTacToeEnv() | |
| # kuhn_env = KuhnPokerEnv() # No longer needed | |
| print("โ Game environment created successfully") | |
| except Exception as e: | |
| print(f"โ Error creating game environment: {e}") | |
| print("๐ Full traceback:", traceback.format_exc()) | |
| GAMES_AVAILABLE = False | |
| else: | |
| print("โ All import methods failed - using fallback interface") | |
| # Initialize model and tokenizer as global variables | |
| model = None | |
| tokenizer = None | |
| def generate_reasoning(prompt): | |
| """Generate reasoning trace using Qwen model.""" | |
| global model, tokenizer | |
| if model is None or tokenizer is None: | |
| return "Error: Model not loaded. Please wait for the GPU to be ready." | |
| inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
| outputs = model.generate(**inputs, max_length=150, do_sample=True, temperature=0.7) | |
| return tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| def create_interface(): | |
| """Create the main Gradio interface.""" | |
| # Custom CSS to style the TicTacToe board | |
| css = """ | |
| .ttt-board { | |
| display: flex; | |
| flex-direction: column; | |
| align-items: center; | |
| max-width: 300px; | |
| margin: 0 auto; | |
| } | |
| .ttt-board > div { | |
| display: flex; | |
| flex-direction: row; | |
| justify-content: center; | |
| gap: 8px; | |
| margin: 4px 0; | |
| } | |
| .ttt-board button { | |
| width: 80px !important; | |
| height: 80px !important; | |
| min-width: 80px !important; | |
| min-height: 80px !important; | |
| max-width: 80px !important; | |
| max-height: 80px !important; | |
| font-size: 24px !important; | |
| font-weight: bold !important; | |
| border: 2px solid #374151 !important; | |
| border-radius: 8px !important; | |
| background: #1f2937 !important; | |
| color: white !important; | |
| display: flex !important; | |
| align-items: center !important; | |
| justify-content: center !important; | |
| } | |
| .ttt-board button:hover { | |
| background: #374151 !important; | |
| border-color: #6b7280 !important; | |
| } | |
| .ttt-board button:disabled { | |
| opacity: 0.8 !important; | |
| cursor: not-allowed !important; | |
| } | |
| .ttt-stats { | |
| text-align: center !important; | |
| margin: 20px 0 !important; | |
| font-size: 16px !important; | |
| } | |
| .ttt-stats p { | |
| margin: 0 !important; | |
| color: #9ca3af !important; | |
| } | |
| """ | |
| with gr.Blocks(title="SPIRAL: Interactive Reasoning Game Simulator", theme=gr.themes.Soft(), css=css) as demo: | |
| gr.Markdown("# ๐ฎ SPIRAL: Interactive Reasoning Game Simulator") | |
| gr.Markdown("Play TicTacToe against an AI, see its step-by-step reasoning, and learn how it thinks!") | |
| if GAMES_AVAILABLE: | |
| def update_board_buttons(): | |
| """Create a list of gr.Button updates from the current board state.""" | |
| updates = [] | |
| for i in range(9): | |
| row, col = divmod(i, 3) | |
| cell = tictactoe_env.board[row, col] | |
| val = "" | |
| interactive = True | |
| if cell == 1: | |
| val = 'โ' | |
| interactive = False | |
| elif cell == -1: | |
| val = 'โญ' | |
| interactive = False | |
| if tictactoe_env.game_over: | |
| interactive = False | |
| updates.append(gr.Button(value=val, interactive=interactive)) | |
| return updates | |
| # TicTacToe specific functions (no longer need get_tictactoe_board_html) | |
| ttt_stats = gr.State({'wins': 0, 'losses': 0, 'draws': 0}) | |
| def minimax(board, player): | |
| """Minimax algorithm to find the best move.""" | |
| # Base cases | |
| winner = tictactoe_env._check_winner() | |
| if winner == 1: # Human wins | |
| return -10, None | |
| elif winner == -1: # AI wins | |
| return 10, None | |
| elif tictactoe_env._is_draw(): | |
| return 0, None | |
| best_move = None | |
| if player == -1: # AI is player -1 (O), maximizing player | |
| best_score = -float('inf') | |
| for move in tictactoe_env._get_valid_actions(): | |
| row, col = divmod(move, 3) | |
| board[row, col] = -1 | |
| score, _ = minimax(board.copy(), 1) | |
| board[row, col] = 0 # Undo move | |
| if score > best_score: | |
| best_score = score | |
| best_move = move | |
| else: # Human is player 1 (X), minimizing player | |
| best_score = float('inf') | |
| for move in tictactoe_env._get_valid_actions(): | |
| row, col = divmod(move, 3) | |
| board[row, col] = 1 | |
| score, _ = minimax(board.copy(), -1) | |
| board[row, col] = 0 # Undo move | |
| if score < best_score: | |
| best_score = score | |
| best_move = move | |
| return best_score, best_move | |
| def play_tictactoe(position, stats): | |
| """Play a TicTacToe move and yield updates for the button grid.""" | |
| if tictactoe_env.game_over: | |
| yield *update_board_buttons(), "Game is over! Click 'New Game' to start again.", "", stats | |
| return | |
| try: | |
| position = int(position) | |
| # Human move | |
| tictactoe_env.step(position) | |
| if tictactoe_env.game_over: | |
| winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "Draw" | |
| if winner == "You": stats['wins'] += 1 | |
| elif winner == "AI": stats['losses'] += 1 | |
| else: stats['draws'] += 1 | |
| yield *update_board_buttons(), f"Game Over! {winner} won!", "", stats | |
| return | |
| # Show "thinking" indicator | |
| yield *update_board_buttons(), "AI is thinking...", "๐ง ...", stats | |
| # AI move | |
| _, ai_action = minimax(tictactoe_env.board.copy(), -1) | |
| if ai_action is None: | |
| valid_actions = tictactoe_env._get_valid_actions() | |
| if not valid_actions: | |
| yield *update_board_buttons(), "Game is a draw!", "", stats | |
| return | |
| ai_action = random.choice(valid_actions) | |
| reasoning_prompt = f"In TicTacToe, the board is currently: {tictactoe_env.board.flatten().tolist()}. The human player (X) played position {position}. I am the AI (O). The available moves are {tictactoe_env._get_valid_actions()}. I have analyzed the game tree using minimax and determined the optimal move is {ai_action}. Explain my strategy." | |
| reasoning = generate_reasoning(reasoning_prompt) | |
| tictactoe_env.step(ai_action) | |
| if tictactoe_env.game_over: | |
| winner = "You" if tictactoe_env.winner == 1 else "AI" if tictactoe_env.winner == -1 else "Draw" | |
| if winner == "You": stats['wins'] += 1 | |
| elif winner == "AI": stats['losses'] += 1 | |
| else: stats['draws'] += 1 | |
| yield *update_board_buttons(), f"Game Over! {winner} won! AI played {ai_action}.", reasoning, stats | |
| else: | |
| yield *update_board_buttons(), f"AI played position {ai_action}. Your turn!", reasoning, stats | |
| except Exception as e: | |
| yield *update_board_buttons(), f"Error: {str(e)}", "", stats | |
| def reset_tictactoe(stats): | |
| """Reset TicTacToe game.""" | |
| tictactoe_env.reset() | |
| return *update_board_buttons(), "New game started! You are โ (X). Click a square to play.", "AI will show its reasoning here...", stats | |
| # Initialize the board on startup | |
| tictactoe_env.reset() | |
| # Simplified layout focusing only on TicTacToe | |
| with gr.Row(): | |
| gr.Markdown("### Play TicTacToe against AI") | |
| gr.Markdown("") # spacer | |
| ttt_reset_btn = gr.Button("๐ New Game", variant="secondary", size="sm") | |
| gr.Markdown("You are โ (X) and go first. Click on a square to make your move.") | |
| # Game board centered | |
| with gr.Column(elem_classes=["ttt-board"]): | |
| board_buttons = [] | |
| for i in range(3): | |
| with gr.Row(elem_classes=["ttt-row"]): | |
| for j in range(3): | |
| pos = i * 3 + j | |
| button = gr.Button("", elem_id=f"ttt-cell-{pos}", size="lg", value="") | |
| board_buttons.append(button) | |
| # Stats display centered below board | |
| with gr.Row(): | |
| ttt_stats_display = gr.Markdown(value="**Wins: 0 | Losses: 0 | Draws: 0**", elem_classes=["ttt-stats"]) | |
| ttt_message = gr.Textbox( | |
| label="Game Status", | |
| value="Choose a position to start!", | |
| lines=2, | |
| interactive=False | |
| ) | |
| ttt_reasoning = gr.Textbox( | |
| label="AI Reasoning", | |
| value="AI will explain its thought process here...", | |
| lines=3, | |
| interactive=False | |
| ) | |
| # Create a combined click handler | |
| def on_board_click(pos, stats): | |
| yield from play_tictactoe(pos, stats) | |
| for i in range(9): | |
| board_buttons[i].click( | |
| fn=on_board_click, | |
| inputs=[gr.State(i), ttt_stats], | |
| outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats] | |
| ) | |
| ttt_reset_btn.click( | |
| fn=reset_tictactoe, | |
| inputs=[ttt_stats], | |
| outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats] | |
| ) | |
| # Update stats display on changes | |
| ttt_stats.change( | |
| fn=lambda s: f"Wins: {s['wins']} | Losses: {s['losses']} | Draws: {s['draws']}", | |
| inputs=ttt_stats, | |
| outputs=ttt_stats_display | |
| ) | |
| # Initialize board display on load | |
| demo.load( | |
| fn=lambda stats: (*update_board_buttons(), "Game ready! You are โ (X). Click a square to play.", "AI will show its reasoning here...", stats), | |
| inputs=[ttt_stats], | |
| outputs=[*board_buttons, ttt_message, ttt_reasoning, ttt_stats] | |
| ) | |
| gr.Markdown("---") | |
| gr.Markdown("๐ง **This is a development preview.** Full SPIRAL training and reasoning capabilities will be added in the next update!") | |
| else: | |
| # Fallback interface when games don't load | |
| gr.Markdown("โ ๏ธ **Game modules could not be loaded.** Showing diagnostic information.") | |
| gr.Markdown("This usually happens when dependencies are still installing on HF Spaces.") | |
| # Show diagnostic info | |
| gr.Markdown("### ๐ Diagnostic Information:") | |
| gr.Markdown(f"- Current directory: `{current_dir}`") | |
| gr.Markdown(f"- Source path: `{src_path}`") | |
| gr.Markdown(f"- Source directory exists: `{os.path.exists(src_path)}`") | |
| if os.path.exists(src_path): | |
| games_path = os.path.join(src_path, 'games') | |
| gr.Markdown(f"- Games directory exists: `{os.path.exists(games_path)}`") | |
| if os.path.exists(games_path): | |
| gr.Markdown(f"- Games directory contents: `{os.listdir(games_path)}`") | |
| # Simple demo interface | |
| with gr.Row(): | |
| simple_input = gr.Textbox(label="Test Input", placeholder="Enter something...") | |
| simple_output = gr.Textbox(label="Output", interactive=False) | |
| def simple_echo(text): | |
| return f"Echo: {text} (Game modules will be available once dependencies install)" | |
| simple_input.submit(fn=simple_echo, inputs=[simple_input], outputs=[simple_output]) | |
| # About Tab (always available) | |
| with gr.TabItem("โน๏ธ About"): | |
| gr.Markdown(""" | |
| ### About SPIRAL | |
| This is a **demo version** of the SPIRAL methodology: *"Self-Play on Zero-Sum Games Incentivizes Reasoning via Multi-Agent Multi-Turn Reinforcement Learning."* | |
| **Current Features:** | |
| - ๐ฏ **TicTacToe**: Play against a random AI opponent | |
| - ๐ **Kuhn Poker**: Experience simplified poker gameplay | |
| - ๐ฎ **Interactive Games**: Real-time game state updates | |
| **Coming Soon:** | |
| - ๐ง **SPIRAL-trained AI**: Opponents trained via self-play | |
| - ๐ **Reasoning Traces**: See step-by-step AI decision-making | |
| - ๐ฌ **Transfer Learning**: Test AI reasoning on math problems | |
| - ๐ **Performance Metrics**: Track AI improvement over time | |
| **Game Rules:** | |
| **TicTacToe:** | |
| - 3x3 grid, get 3 in a row to win | |
| - You are X, AI is O | |
| - Numbers 0-8 represent board positions | |
| **Kuhn Poker:** | |
| - 3 cards: Jack (lowest), Queen, King (highest) | |
| - Each player gets 1 card, antes 1 chip | |
| - Actions: Check/Call, Bet (+1 chip), Fold | |
| - Higher card wins if both call/check | |
| **Technical Details:** | |
| - Built with Gymnasium environments | |
| - Gradio web interface | |
| - Ready for SPIRAL training integration | |
| """) | |
| gr.Markdown("**New in this version:** Visual boards, stats tracking, and transfer test stub!") | |
| if not GAMES_AVAILABLE: | |
| gr.Markdown("---") | |
| gr.Markdown("๐ **Dependencies are loading.** Check the diagnostic info above and refresh in a few minutes!") | |
| return demo | |
| def main(): | |
| """ | |
| Main function to load model, create interface, and launch the Gradio app. | |
| Wrapped with @spaces.GPU to allocate a GPU for this Space. | |
| """ | |
| global model, tokenizer | |
| print("๐ Starting main application...") | |
| print("Loading configuration...") | |
| with open('config.yaml', 'r') as f: | |
| config = yaml.safe_load(f) | |
| model_name = config['model']['name'] | |
| quantization_params = config['model'].get('quantization', {}) | |
| print(f"๐ฆ Model Name: {model_name}") | |
| print(f"โ๏ธ Quantization Params: {quantization_params}") | |
| # Create BitsAndBytesConfig if quantization is enabled | |
| if quantization_params and quantization_params.get('load_in_4bit'): | |
| print("๐ก 4-bit quantization enabled. Creating BitsAndBytesConfig...") | |
| compute_dtype_str = quantization_params.get("bnb_4bit_compute_dtype", "float16") | |
| if compute_dtype_str == "bfloat16": | |
| compute_dtype = torch.bfloat16 | |
| else: | |
| compute_dtype = torch.float16 # Default to float16 | |
| bnb_config = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_quant_type=quantization_params.get("bnb_4bit_quant_type", "nf4"), | |
| bnb_4bit_compute_dtype=compute_dtype, | |
| bnb_4bit_use_double_quant=quantization_params.get("bnb_4bit_use_double_quant", True), | |
| ) | |
| # Using device_map="auto" is recommended for multi-GPU setups and large models | |
| print("๐ง Loading 4-bit quantized model...") | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| quantization_config=bnb_config, | |
| device_map="auto" | |
| ) | |
| else: | |
| print("๐ง Loading model without quantization...") | |
| # Fallback for no quantization | |
| model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") | |
| print("โ๏ธ Loading tokenizer...") | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| print("โ Model and tokenizer loaded successfully.") | |
| print("๐จ Creating Gradio interface...") | |
| demo = create_interface() | |
| print("๐ Launching Gradio app...") | |
| demo.launch() | |
| if __name__ == "__main__": | |
| main() | |