Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """app.py | |
| Automatically generated by Colab. | |
| Original file is located at | |
| https://colab.research.google.com/drive/1NU6NHjan4eF9IVHR549tKLRVNUQ_dBD7 | |
| """ | |
| import os | |
| import torch | |
| import numpy as np | |
| import requests | |
| import json | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| import torch.nn as nn | |
# ---- Load env variables ----
# Read a local .env file (if one exists) into the process environment so the
# key can also be supplied that way during local development. By default
# load_dotenv() does not override variables that are already set.
load_dotenv()
OPENROUTER_KEY = os.getenv("OPENROUTER_KEY")
if not OPENROUTER_KEY:
    # Fail fast at import time: every explanation request needs this key.
    raise ValueError("OPENROUTER_KEY not set in environment variables.")
| # ---- Blackjack Environment ---- | |
| import random | |
class BlackjackEnv:
    """Simplified single-player blackjack environment.

    States are tuples ``(dealer_card, player_total, usable_ace)`` where
    ``usable_ace`` is ``1`` when an ace in the player's hand is currently
    counted as 11 and ``0`` otherwise. Action ``1`` means hit; any other
    value is treated as stick.

    NOTE(review): cards are drawn uniformly from 1-10 and the dealer draws
    exactly one extra card when the player sticks. Both are simplifications
    of casino blackjack; they are left unchanged here because the trained
    network presumably saw the same rules -- confirm against the training
    code before altering them.
    """

    def __init__(self):
        self.dealer = []
        self.player = []
        # Cached usable-ace flag for the player's current hand.
        self.usable_ace_player = False

    def draw_card(self):
        """Return a random card value from 1 (ace) to 10 inclusive."""
        return random.randint(1, 10)

    def sum_hand(self, hand):
        """Return ``(total, usable_ace)`` for *hand*.

        An ace (value 1) is counted as 11 when doing so does not push the
        total above 21; in that case ``usable_ace`` is ``True``.
        """
        total = sum(hand)
        if 1 in hand and total + 10 <= 21:
            return total + 10, True
        return total, False

    def reset(self):
        """Deal two player cards and one dealer card; return the first state."""
        self.player = [self.draw_card(), self.draw_card()]
        self.dealer = [self.draw_card()]
        total, usable_ace = self.sum_hand(self.player)
        self.usable_ace_player = usable_ace
        return (self.dealer[0], total, int(usable_ace))

    def step(self, action):
        """Apply *action* and return ``(state, reward, done)``.

        Hitting draws one card for the player; going over 21 ends the hand
        with reward -1, otherwise the hand continues with reward 0.
        Sticking draws one card for the dealer and ends the hand with
        reward 1, -1 or 0 for a higher, lower or equal player total.
        """
        if action == 1:
            self.player.append(self.draw_card())
            total, usable_ace = self.sum_hand(self.player)
            # Keep the cached flag in sync with the hand: a hit can turn a
            # usable ace into a hard ace (or add a new usable one).
            self.usable_ace_player = usable_ace
            state = (self.dealer[0], total, int(usable_ace))
            if total > 21:
                return state, -1, True
            return state, 0, False

        dealer_total, _ = self.sum_hand(self.dealer + [self.draw_card()])
        # Recompute the flag from the actual hand instead of trusting the
        # cached attribute, which used to go stale after a hit.
        player_total, usable_ace = self.sum_hand(self.player)
        self.usable_ace_player = usable_ace
        state = (self.dealer[0], player_total, int(usable_ace))
        if dealer_total < player_total:
            reward = 1
        elif dealer_total > player_total:
            reward = -1
        else:
            reward = 0
        return state, reward, True
| # ---- QNetwork ---- | |
class QNetwork(nn.Module):
    """Fully connected network mapping a state vector to one value per action.

    Two hidden layers of ``hidden_size`` units with ReLU activations sit
    between the ``state_size`` inputs and the ``action_size`` outputs.
    """

    def __init__(self, state_size=3, hidden_size=128, action_size=2):
        super().__init__()
        layers = [
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
        ]
        # The attribute name and layer order define the state_dict keys
        # ("model.0.weight", ...), so saved weights keep loading.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for the input batch *x*."""
        action_values = self.model(x)
        return action_values
# ---- Load model ----
model = QNetwork()
model_path = "qnetwork_blackjack_weights.pth"
# map_location pins every tensor to the CPU so that a checkpoint saved from
# a GPU session still loads on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the installed torch
# version supports it).
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
model.eval()  # the network is only used for inference below
env = BlackjackEnv()
| # ---- LLM Explanation ---- | |
def explain_action(state, action):
    """Ask the OpenRouter chat API to explain the chosen blackjack action.

    Args:
        state: ``(dealer_card, player_total, usable_ace)`` tuple.
        action: ``1`` for hit; any other value is described as stick.

    Returns:
        The explanation text on success, otherwise a string describing the
        failure ("LLM error: ..." for a non-200 response, "LLM call
        failed: ..." when the request or response parsing fails). The
        function never raises for network problems so the UI always gets
        something to display.
    """
    prompt = f"""
    You are a blackjack strategy explainer. The player has a total of {state[1]}.
    The dealer is showing {state[0]}. Usable ace: {bool(state[2])}.
    The DQN model chose to {'Hit' if action == 1 else 'Stick'}.
    Explain why this action makes sense in 2-3 sentences.
    """
    headers = {
        "Authorization": f"Bearer {OPENROUTER_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "mistralai/mistral-7b-instruct",
        "messages": [
            {"role": "system", "content": "You explain blackjack strategies clearly."},
            {"role": "user", "content": prompt},
        ],
    }
    try:
        # Without a timeout requests waits indefinitely, which would hang
        # the UI callback if the API stalls.
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=30,
        )
        if response.status_code == 200:
            return response.json()['choices'][0]['message']['content']
        return f"LLM error: {response.status_code} - {response.text}"
    except (requests.RequestException, KeyError, IndexError, TypeError, ValueError) as e:
        # Network failures, invalid JSON, or an unexpected response shape.
        return f"LLM call failed: {str(e)}"
| # ---- Gradio App ---- | |
def play_hand():
    """Deal a new hand, pick the network's action and fetch an explanation.

    Returns a list of six strings in the order the interface displays them:
    player sum, dealer card, usable-ace flag, action name, the network
    outputs, and the LLM explanation.
    """
    observation = env.reset()
    dealer_card, player_sum, usable_ace = observation

    batch = torch.tensor(observation, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        q_values = model(batch)
    action = torch.argmax(q_values).item()

    if action == 1:
        action_name = "Hit"
    else:
        action_name = "Stick"
    explanation = explain_action(observation, action)

    return [
        str(player_sum),
        str(dealer_card),
        str(bool(usable_ace)),
        action_name,
        str(q_values.numpy().tolist()),
        explanation,
    ]
# One read-only text box per value returned by play_hand, in the same order.
demo = gr.Interface(
    fn=play_hand,
    inputs=[],
    outputs=[
        gr.Textbox(label=label)
        for label in (
            "Player Sum",
            "Dealer Card",
            "Usable Ace",
            "DQN Action",
            "Q-values",
            "LLM Explanation",
        )
    ],
    title="🧠 Blackjack Tutor: DQN + LLM",
    description=(
        "Play a hand of blackjack. See how a Deep Q Network plays, "
        "and get a natural language explanation from Mistral-7B via OpenRouter."
    ),
)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()

# Debug aid: list the entries of the current working directory.
# NOTE(review): looks like a leftover notebook cell -- confirm it is still
# wanted.
import os

print(os.listdir())