Dylan-Kaneshiro committed on
Commit
3bd038f
·
verified ·
1 Parent(s): 624a315

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -2
app.py CHANGED
@@ -1,8 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import random
3
  import pandas as pd
4
- from helpers import *
5
- from mabwiser.mab import MAB, LearningPolicy
6
 
7
  # Load songs dataset
8
  #file_path = "/content/drive/My Drive/MIT/RealTime/songs_single_genre.csv"
 
1
+ from mabwiser.mab import MAB, LearningPolicy
2
+
3
def bandit_factory(bandit_type, arms, *, epsilon=0.3, ucb_alpha=1,
                   ns_alpha=0.2, seed=1234, initial_reward=3):
    """Build a bandit of the requested type and warm-start every arm.

    Args:
        bandit_type: One of ``"Epsilon Greedy"``, ``"UCB"`` or
            ``"Non-Stationary"``.
        arms: Sized collection of arm identifiers.
        epsilon: Exploration rate for the epsilon-greedy and non-stationary
            bandits.
        ucb_alpha: ``alpha`` passed to ``LearningPolicy.UCB1``.
        ns_alpha: Step size passed to ``NSBandit``.
        seed: Seed passed to ``MAB`` (unused by ``NSBandit``).
        initial_reward: Reward fed once to every arm as a warm start.

    Returns:
        The constructed bandit, after one ``partial_fit`` call that gives
        each arm a single observation of ``initial_reward``.

    Raises:
        ValueError: If ``bandit_type`` is not one of the supported names.
    """
    if bandit_type == "Epsilon Greedy":
        result = MAB(
            arms=arms,
            learning_policy=LearningPolicy.EpsilonGreedy(epsilon=epsilon),
            seed=seed,
        )
    elif bandit_type == "UCB":
        result = MAB(
            arms=arms,
            learning_policy=LearningPolicy.UCB1(alpha=ucb_alpha),
            seed=seed,
        )
    elif bandit_type == "Non-Stationary":
        result = NSBandit(arms=arms, epsilon=epsilon, alpha=ns_alpha)
    else:
        raise ValueError("Invalid bandit type")

    # Warm start: one identical observation per arm, so no arm begins
    # without a reward estimate.
    result.partial_fit(decisions=arms, rewards=[initial_reward] * len(arms))
    return result
19
+
20
class NSBandit:
    """Epsilon-greedy bandit with constant-step-size reward estimates.

    Each arm keeps a running estimate updated as
    ``mean += alpha * (reward - mean)``, which weights recent rewards more
    heavily than old ones. Provides ``partial_fit`` and ``predict``.
    """

    def __init__(self, arms, epsilon, alpha):
        """Create a bandit over ``arms``.

        Args:
            arms: Iterable of hashable arm identifiers.
            epsilon: Probability of exploring a non-best arm in ``predict``.
            alpha: Step size of the running reward estimate.
        """
        # Materialize once so a generator argument is not left exhausted.
        self.arms = list(arms)
        self.epsilon = epsilon
        self.alpha = alpha
        # None marks an arm that has not received any reward yet.
        self.means = {arm: None for arm in self.arms}
        self.t = 0

    def partial_fit(self, decisions, rewards):
        """Update the estimates with paired ``decisions`` and ``rewards``.

        The first reward seen for an arm becomes its estimate; later rewards
        move the estimate by ``alpha`` times the error. ``self.t`` counts the
        observations processed.

        Raises:
            KeyError: If a decision is not one of the bandit's arms.
        """
        for arm, reward in zip(decisions, rewards):
            current = self.means[arm]
            if current is None:
                self.means[arm] = reward
            else:
                self.means[arm] = current + self.alpha * (reward - current)
            self.t += 1

    def predict(self):
        """Choose the next arm to play.

        Arms with no reward yet are tried first (uniformly at random).
        Otherwise, with probability ``epsilon`` a uniformly random arm other
        than the current best is returned; else the best arm is returned.
        """
        unseen = [arm for arm, mean in self.means.items() if mean is None]
        if unseen:
            return random.choice(unseen)

        best = max(self.means, key=self.means.get)
        if random.random() < self.epsilon:
            # Iterate the dict (insertion-ordered) rather than a set so the
            # choice is reproducible under a fixed random seed.
            others = [arm for arm in self.means if arm != best]
            # With a single arm there is nothing to explore; fall through.
            if others:
                return random.choice(others)
        return best
48
+
49
  import gradio as gr
50
  import random
51
  import pandas as pd
 
 
52
 
53
  # Load songs dataset
54
  #file_path = "/content/drive/My Drive/MIT/RealTime/songs_single_genre.csv"