GenAIDevTOProd committed on
Commit
fa91c7c
·
verified ·
1 Parent(s): d5bd75d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """app.py
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1NU6NHjan4eF9IVHR549tKLRVNUQ_dBD7
8
+ """
9
+
10
+ import os
11
+ import torch
12
+ import numpy as np
13
+ import requests
14
+ import json
15
+ import gradio as gr
16
+ from dotenv import load_dotenv
17
+ import torch.nn as nn
18
+
19
# ---- Load env variables ----
# load_dotenv()  # NOTE(review): left disabled — presumably the host injects
# env vars directly (e.g. HF Spaces secrets); re-enable for local .env use.
OPENROUTER_KEY = os.getenv("OPENROUTER_KEY")
# Fail fast at import time: every OpenRouter request below requires this key.
if not OPENROUTER_KEY:
    raise ValueError("OPENROUTER_KEY not set in environment variables.")
24
+
25
# ---- Blackjack Environment ----
import random

class BlackjackEnv:
    """Minimal blackjack environment.

    State is a 3-tuple: (dealer's visible card, player's hand total,
    usable-ace flag as int). Simplifications vs. casino rules: card values
    are drawn uniformly from 1-10 (no extra weight for 10/J/Q/K) and there
    is no doubling, splitting, or natural-blackjack bonus.
    """

    def __init__(self):
        self.dealer = []                 # dealer's cards; index 0 is visible
        self.player = []                 # player's cards
        self.usable_ace_player = False   # ace currently counted as 11

    def draw_card(self):
        """Draw a card value uniformly from 1..10 (1 is an ace)."""
        return random.randint(1, 10)

    def sum_hand(self, hand):
        """Return (best_total, usable_ace) for *hand*.

        An ace counts as 11 when doing so does not bust the hand.
        """
        total = sum(hand)
        if 1 in hand and total + 10 <= 21:
            return total + 10, True
        return total, False

    def reset(self):
        """Start a new hand; return the initial state tuple."""
        self.player = [self.draw_card(), self.draw_card()]
        self.dealer = [self.draw_card()]
        total, usable_ace = self.sum_hand(self.player)
        self.usable_ace_player = usable_ace
        return (self.dealer[0], total, int(usable_ace))

    def step(self, action):
        """Take an action (1 = hit, 0 = stick).

        Returns (state, reward, done); reward is +1 win / 0 push / -1 loss
        from the player's perspective.
        """
        if action == 1:  # hit: draw a card; bust ends the episode
            self.player.append(self.draw_card())
            total, usable_ace = self.sum_hand(self.player)
            # Keep the stored ace flag in sync (original left it stale
            # after a hit, so a later stick reported the wrong state).
            self.usable_ace_player = usable_ace
            if total > 21:
                return (self.dealer[0], total, int(usable_ace)), -1, True
            return (self.dealer[0], total, int(usable_ace)), 0, False

        # Stick: dealer draws until reaching at least 17 (standard rule).
        # The original drew exactly one card and never checked for a dealer
        # bust, so a busted dealer (e.g. 26) could still "beat" the player.
        player_total, _ = self.sum_hand(self.player)
        state = (self.dealer[0], player_total, int(self.usable_ace_player))
        dealer_hand = list(self.dealer)
        dealer_total, _ = self.sum_hand(dealer_hand)
        while dealer_total < 17:
            dealer_hand.append(self.draw_card())
            dealer_total, _ = self.sum_hand(dealer_hand)
        if dealer_total > 21 or dealer_total < player_total:
            return state, 1, True    # dealer busts or falls short: player wins
        if dealer_total > player_total:
            return state, -1, True
        return state, 0, True        # push
68
+
69
# ---- QNetwork ----
class QNetwork(nn.Module):
    """Feed-forward Q-value approximator.

    Maps a state vector (default 3 features: dealer card, player total,
    usable-ace flag) to one Q-value per action (default 2: stick/hit).
    """

    def __init__(self, state_size=3, hidden_size=128, action_size=2):
        super(QNetwork, self).__init__()
        # Two hidden ReLU layers; final layer emits raw Q-values.
        layers = [
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of state vectors to per-action Q-values."""
        return self.model(x)
83
+
84
# ---- Load model ----
model = QNetwork()
# Path is overridable via MODEL_PATH; the default is the original Colab
# location, which only exists inside a Colab runtime — set MODEL_PATH when
# deploying elsewhere (e.g. alongside app.py on HF Spaces).
model_path = os.getenv("MODEL_PATH", "/content/sample_data/qnetwork_blackjack_weights.pth")
# map_location lets CUDA-trained weights load on a CPU-only host.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # inference only

env = BlackjackEnv()
91
+
92
# ---- LLM Explanation ----
def explain_action(state, action):
    """Ask an LLM (via OpenRouter) to explain the DQN's chosen action.

    Args:
        state: (dealer_card, player_total, usable_ace) tuple.
        action: 1 = Hit, 0 = Stick.

    Returns:
        The explanation text, or a human-readable error string on failure
        (this function never raises — the Gradio UI displays whatever it
        returns).
    """
    prompt = f"""
    You are a blackjack strategy explainer. The player has a total of {state[1]}.
    The dealer is showing {state[0]}. Usable ace: {bool(state[2])}.
    The DQN model chose to {'Hit' if action == 1 else 'Stick'}.
    Explain why this action makes sense in 2-3 sentences.
    """

    headers = {
        "Authorization": f"Bearer {OPENROUTER_KEY}",
        "Content-Type": "application/json"
    }

    data = {
        "model": "mistralai/mistral-7b-instruct",
        "messages": [
            {"role": "system", "content": "You explain blackjack strategies clearly."},
            {"role": "user", "content": prompt}
        ]
    }

    try:
        # json= handles serialization (replacing data=json.dumps(...));
        # timeout prevents a hung request from freezing the UI forever.
        response = requests.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=30,
        )
        if response.status_code == 200:
            return response.json()['choices'][0]['message']['content']
        return f"LLM error: {response.status_code} - {response.text}"
    except Exception as e:
        return f"LLM call failed: {str(e)}"
122
+
123
# ---- Gradio App ----
def play_hand():
    """Deal a fresh hand, let the DQN choose Hit/Stick, and return the
    state, the decision, the raw Q-values, and an LLM explanation —
    all as strings for the Gradio textboxes."""
    observation = env.reset()
    features = torch.tensor(observation, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        q_values = model(features)
    chosen = torch.argmax(q_values).item()

    dealer_card, player_sum, usable_ace = observation
    decision = "Hit" if chosen == 1 else "Stick"
    reasoning = explain_action(observation, chosen)

    return [
        str(player_sum),
        str(dealer_card),
        str(bool(usable_ace)),
        decision,
        str(q_values.numpy().tolist()),
        reasoning,
    ]
143
+
144
# Input-less interface: each click of the run button deals one hand.
demo = gr.Interface(
    fn=play_hand,
    inputs=[],
    outputs=[
        gr.Textbox(label=label)
        for label in (
            "Player Sum",
            "Dealer Card",
            "Usable Ace",
            "DQN Action",
            "Q-values",
            "LLM Explanation",
        )
    ],
    title="🧠 Blackjack Tutor: DQN + LLM",
    description="Play a hand of blackjack. See how a Deep Q Network plays, and get a natural language explanation from Mistral-7B via OpenRouter."
)
158
+
159
+ if __name__ == "__main__":
160
+ demo.launch()
161
+
162
+ import os
163
+ print(os.listdir())