|
|
import os |
|
|
import random |
|
|
import csv |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
import gradio as gr |
|
|
from tensorflow.keras.models import Sequential, load_model |
|
|
from tensorflow.keras.layers import LSTM, Dense, Embedding |
|
|
from tensorflow.keras.preprocessing.sequence import pad_sequences |
|
|
from tensorflow.keras.preprocessing.text import Tokenizer |
|
|
|
|
|
|
|
|
# Hide all GPUs from TensorFlow (CPU only) and turn off oneDNN custom ops.
# NOTE(review): tensorflow is already imported above this point; TensorFlow
# likely reads these variables at import / device-initialisation time, so
# they may have no effect here. Consider moving them above
# `import tensorflow` -- TODO confirm.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "-1",
    "TF_ENABLE_ONEDNN_OPTS": "0",
})

# Only let TensorFlow's Python logger emit ERROR and above.
tf.get_logger().setLevel('ERROR')
|
|
|
|
|
|
|
|
# Files that persist state between runs: the move history and the saved model.
csv_filename = "game_moves.csv"
model_filename = "lstm_model.h5"

# Move name -> class index (play_game validates input against its keys), and
# the inverse mapping used to turn a predicted class index back into a move.
choices = {'rock': 0, 'paper': 1, 'scissors': 2}
rev_choices = {index: move for move, index in choices.items()}
|
|
|
|
|
|
|
|
# Make sure the move-history CSV starts with its header row. This covers both
# the first run (file missing) and a file that exists but is empty: load_data
# always discards the first row as the header. Without a header, the first
# recorded move would otherwise be lost.
if not os.path.exists(csv_filename) or os.path.getsize(csv_filename) == 0:
    with open(csv_filename, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Player Choice", "Computer Choice", "Result"])
|
|
|
|
|
def load_data(filename=None):
    """Load the player's past moves from the CSV history file.

    Args:
        filename: path of the history CSV. Defaults to ``csv_filename``.

    Returns:
        The first column ("Player Choice") of every non-blank data row, oldest
        first. The first row is treated as the header and skipped. Returns
        ``[]`` when the file does not exist or contains no rows at all.
    """
    if filename is None:
        filename = csv_filename
    try:
        # The csv module expects file objects opened with newline="".
        with open(filename, mode="r", newline="") as file:
            reader = csv.reader(file)
            # Skip the header. The default avoids an uncaught StopIteration
            # when the file is empty.
            next(reader, None)
            return [row[0] for row in reader if row]
    except FileNotFoundError:
        return []
|
|
|
|
|
def train_lstm_model(data):
    """Train an LSTM that predicts the player's next move from their last five.

    Each training example is a window of five consecutive moves, encoded with
    ``choices`` (rock=0, paper=1, scissors=2). Its label is the move that
    immediately followed. The fitted model is saved to ``model_filename``.

    Args:
        data: the player's move history, oldest first, as move-name strings.
            Entries that are not keys of ``choices`` are ignored.

    Returns:
        The trained Keras model, or ``None`` when there are too few valid moves
        to build at least one training example.
    """
    window = 5  # moves of context per example (get_computer_choice uses 5 too)
    if len(data) < window + 1:
        return None

    # Encode with ``choices`` so the class indices the model learns are
    # exactly the ones ``rev_choices`` decodes. A Keras Tokenizer(num_words=3)
    # is unsuitable here for two reasons. It keeps only word indices below
    # num_words, so "scissors" (index 3) would vanish. Its 1-based indices
    # also do not line up with ``rev_choices``.
    codes = [choices[move] for move in data if move in choices]
    if len(codes) < window + 1:
        return None

    # Sliding windows: X has shape (n_examples, window); y holds the move that
    # came right after each window.
    X = np.array(
        [codes[i:i + window] for i in range(len(codes) - window)],
        dtype="int32",
    )
    y = np.array(codes[window:], dtype="int32")

    model = Sequential([
        Embedding(input_dim=len(choices), output_dim=10, input_length=window),
        LSTM(30, return_sequences=False),
        Dense(len(choices), activation="softmax"),
    ])

    model.compile(loss="sparse_categorical_crossentropy",
                  optimizer="adam",
                  metrics=["accuracy"])

    model.fit(X, y, epochs=10, batch_size=1, verbose=0)
    model.save(model_filename)
    return model
|
|
|
|
|
def get_computer_choice(model, past_moves):
    """Predict the player's next move and return the move that beats it.

    Args:
        model: a Keras model as built by ``train_lstm_model`` (or loaded from
            ``model_filename``), or ``None``.
        past_moves: the player's move history, oldest first.

    Returns:
        "rock", "paper" or "scissors". A uniformly random move is returned
        in any of these cases:

        - there is no model;
        - there are fewer than five past moves;
        - one of the last five moves is unrecognised;
        - prediction fails.
    """
    import logging

    moves = ["rock", "paper", "scissors"]
    window = 5  # must match the window train_lstm_model trains on
    if len(past_moves) < window or model is None:
        return random.choice(moves)

    recent = past_moves[-window:]
    if any(move not in choices for move in recent):
        return random.choice(moves)

    # Same encoding as train_lstm_model: a single row of move codes,
    # input shape (1, window).
    sequence = np.array([[choices[move] for move in recent]], dtype="int32")

    try:
        prediction = model.predict(sequence, verbose=0)
        predicted_choice = rev_choices[int(np.argmax(prediction))]
    except Exception:
        # Best effort: an unusable model (e.g. an incompatible saved file)
        # degrades to random play instead of breaking the game. The failure
        # is logged rather than hidden.
        logging.getLogger(__name__).warning(
            "Move prediction failed; playing randomly.", exc_info=True
        )
        return random.choice(moves)

    # Play the move that beats the predicted one.
    counter_choices = {'rock': 'paper', 'paper': 'scissors', 'scissors': 'rock'}
    return counter_choices[predicted_choice]
|
|
|
|
|
def get_winner(player, computer):
    """Return the round's result message from the player's point of view.

    Identical moves tie. The player wins with rock vs scissors, scissors vs
    paper, or paper vs rock. Every other pairing counts as a computer win.
    """
    player_wins = (
        ("rock", "scissors"),
        ("scissors", "paper"),
        ("paper", "rock"),
    )
    if player == computer:
        return "It's a tie!"
    if (player, computer) in player_wins:
        return "You win!"
    return "Computer wins!"
|
|
|
|
|
def save_move(player, computer, result):
    """Append one round (player move, computer move, result) to ``csv_filename``."""
    row = [player, computer, result]
    with open(csv_filename, mode="a", newline="") as history:
        csv.writer(history).writerow(row)
|
|
|
|
|
|
|
|
# The player's move history, oldest first; play_game appends to it.
past_moves = load_data()


def _load_or_train_model(moves):
    """Return the model saved at ``model_filename`` or, failing that, a new one.

    If the saved file is missing or cannot be loaded, a model is trained from
    ``moves`` with ``train_lstm_model``. That function returns ``None`` while
    there is too little history.
    """
    if os.path.exists(model_filename):
        try:
            return load_model(model_filename)
        except Exception:
            # A corrupt or incompatible model file should not stop the app.
            # Fall through and rebuild the model from the move history.
            import logging
            logging.getLogger(__name__).warning(
                "Could not load %s; retraining from history.", model_filename,
                exc_info=True,
            )
    return train_lstm_model(moves)


model = _load_or_train_model(past_moves)
|
|
|
|
|
def play_game(player_choice):
    """Play one round against the AI and return a Markdown summary.

    The round proceeds as follows:

    1. Reject any move that is not a key of ``choices``.
    2. Pick the computer's move and score the round.
    3. Record it in the CSV and in ``past_moves``.
    4. Whenever the history length reaches a multiple of ten, retrain the
       global ``model``.
    """
    global past_moves, model

    if player_choice not in choices:
        return "Invalid choice. Choose rock, paper, or scissors."

    computer_choice = (
        random.choice(["rock", "paper", "scissors"])
        if model is None
        else get_computer_choice(model, past_moves)
    )
    result = get_winner(player_choice, computer_choice)

    save_move(player_choice, computer_choice, result)
    past_moves.append(player_choice)

    games_played = len(past_moves)
    if games_played >= 6 and games_played % 10 == 0:
        model = train_lstm_model(past_moves)

    return (
        f"**Your choice:** {player_choice}\n\n"
        f"**Computer choice:** {computer_choice}\n\n"
        f"**Result:** {result}\n\n"
        f"*Total games played: {games_played}*"
    )
|
|
|
|
|
|
|
|
with gr.Blocks(title="Rock Paper Scissors AI", theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown("# 🪨 📄 ✂️ Rock Paper Scissors AI")
    gr.Markdown("Play against an AI that learns from your moves and tries to beat you!")

    with gr.Row():
        # Left column: move picker plus the button that plays a round.
        with gr.Column(scale=1):
            move_input = gr.Radio(
                label="Choose your move",
                choices=["rock", "paper", "scissors"],
                value="rock",
            )
            submit_btn = gr.Button("Play!", variant="primary")

        # Right column: shows the Markdown summary returned by play_game.
        with gr.Column(scale=2):
            output = gr.Markdown("## Game will start here...")

    # Each click plays one round with the selected move.
    submit_btn.click(play_game, inputs=move_input, outputs=output)

    gr.Markdown("### How it works:")
    gr.Markdown("""
    1. The AI uses an LSTM neural network to learn from your move patterns
    2. It predicts your next move based on your last 5 moves
    3. It counters your predicted move to try to win
    4. The model improves as you play more games
    """)
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio app; show_error=True asks Gradio to surface handler
    # errors in the browser.
    demo.launch(show_error=True, debug=False)