Jabuszko commited on
Commit
a58bbc0
·
verified ·
1 Parent(s): e8e26aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -12,7 +12,7 @@ from transformers import AutoTokenizer, AutoModel
12
 
13
  ACTIONS = ["TRIP", "GITHUB", "MAIL"]
14
  NUM_ACTIONS = len(ACTIONS)
15
- DATASET_PATH = os.path.join(os.path.dirname(__file__), "dataset.jsonl")
16
 
17
  # Confidence threshold - below this returns NONE
18
  CONFIDENCE_THRESHOLD = 0.6
@@ -230,16 +230,15 @@ class RLAgent:
230
 
231
 
232
  def load_dataset():
233
- """Load and parse the dataset."""
234
- data = []
235
 
236
- with open(DATASET_PATH, "r") as f:
237
- for line in f:
238
- item = json.loads(line)
239
- user_msg = item["messages"][1]["content"]
240
- label = item["messages"][2]["content"]
241
- if label in ACTIONS:
242
- data.append((user_msg, ACTIONS.index(label)))
243
 
244
  random.shuffle(data)
245
  return data
 
12
 
13
  ACTIONS = ["TRIP", "GITHUB", "MAIL"]
14
  NUM_ACTIONS = len(ACTIONS)
15
+ DATASET_PATH = "iteratehack/code19-dataset"
16
 
17
  # Confidence threshold - below this returns NONE
18
  CONFIDENCE_THRESHOLD = 0.6
 
230
 
231
 
232
  def load_dataset():
233
+ """Load dataset from Hugging Face Datasets."""
234
+ dataset = load_dataset(HF_DATASET, split=HF_SPLIT)
235
 
236
+ data = []
237
+ for item in dataset:
238
+ user_msg = item["messages"][1]["content"]
239
+ label = item["messages"][2]["content"]
240
+ if label in ACTIONS:
241
+ data.append((user_msg, ACTIONS.index(label)))
 
242
 
243
  random.shuffle(data)
244
  return data