bstraehle committed on
Commit
72c76b2
·
verified ·
1 Parent(s): 4a3493d

Update utils/utils.py

Browse files
Files changed (1) hide show
  1. utils/utils.py +14 -25
utils/utils.py CHANGED
@@ -4,19 +4,19 @@ from docx import Document
4
  from pptx import Presentation
5
  from datasets import load_dataset
6
 
7
- QUESTION_TYPE_GAIA = "gaia"
8
- QUESTION_TYPE_HLE = "hle"
9
 
10
- QUESTION_FILE_PATH_GAIA = "files/gaia_validation.jsonl"
11
- QUESTION_FILE_PATH_HLE = "files/hle_validation.jsonl"
12
 
13
- def get_questions_from_file(question_type, level):
14
  file_path = ""
15
 
16
- if question_type == QUESTION_TYPE_GAIA:
17
- file_path = QUESTION_FILE_PATH_GAIA
18
- elif question_type == QUESTION_TYPE_HLE:
19
- file_path = QUESTION_FILE_PATH_HLE
20
 
21
  df = pd.read_json(file_path, lines=True)
22
 
@@ -29,13 +29,7 @@ def get_questions_from_file(question_type, level):
29
 
30
  return result
31
 
32
- def get_questions_from_dataset(question_type, level):
33
- # Extract dataset type from file path (e.g., "gaia" or "hle")
34
- basename = os.path.splitext(os.path.basename(file_path))[0]
35
- print(f"basename={basename}")
36
- dataset_type = basename.replace("_validation", "")
37
- print(f"basename={dataset_type}")
38
-
39
  # Get space ID from environment, defaulting to "bstraehle/gaia"
40
  space_id = os.environ.get("SPACE_ID", "bstraehle/gaia")
41
  # Extract username from space_id
@@ -50,11 +44,11 @@ def get_questions_from_dataset(question_type, level):
50
  df = dataset.to_pandas()
51
 
52
  # Filter by dataset type using the task_id prefix
53
- if dataset_type == "gaia":
54
- print(f"dataset_type={dataset_type}")
55
  df = df[df["task_id"].str.startswith("gaia-")]
56
- elif dataset_type == "hle":
57
- print(f"dataset_type={dataset_type}")
58
  df = df[df["task_id"].str.startswith("hle-")]
59
 
60
  # Filter by level if level > 0 (for GAIA benchmark)
@@ -63,11 +57,6 @@ def get_questions_from_dataset(question_type, level):
63
  df = df[df["Level"] == level]
64
 
65
  result=[]
66
-
67
- for _, row in df.iterrows():
68
- result.append([row["Question"], row["Final answer"], row["file_name"]])
69
-
70
- return result
71
 
72
  def is_ext(file_path, ext):
73
  return os.path.splitext(file_path)[1].lower() == ext.lower()
 
4
  from pptx import Presentation
5
  from datasets import load_dataset
6
 
7
+ DATASET_TYPE_GAIA = "gaia"
8
+ DATASET_TYPE_HLE = "hle"
9
 
10
+ DATASET_FILE_PATH_GAIA = "files/gaia_validation.jsonl"
11
+ DATASET_FILE_PATH_HLE = "files/hle_validation.jsonl"
12
 
13
+ def get_questions_from_file(dataset_type, level):
14
  file_path = ""
15
 
16
+ if dataset_type == DATASET_TYPE_GAIA:
17
+ file_path = DATASET_FILE_PATH_GAIA
18
+ elif dataset_type == DATASET_TYPE_HLE:
19
+ file_path = DATASET_FILE_PATH_HLE
20
 
21
  df = pd.read_json(file_path, lines=True)
22
 
 
29
 
30
  return result
31
 
32
+ def get_questions_from_dataset(dataset_type, level):
 
 
 
 
 
 
33
  # Get space ID from environment, defaulting to "bstraehle/gaia"
34
  space_id = os.environ.get("SPACE_ID", "bstraehle/gaia")
35
  # Extract username from space_id
 
44
  df = dataset.to_pandas()
45
 
46
  # Filter by dataset type using the task_id prefix
47
+ if dataset_type == DATASET_TYPE_GAIA:
48
+ print(f"filtering for dataset_type={dataset_type}")
49
  df = df[df["task_id"].str.startswith("gaia-")]
50
+ elif dataset_type == DATASET_TYPE_HLE:
51
+ print(f"filtering for dataset_type={dataset_type}")
52
  df = df[df["task_id"].str.startswith("hle-")]
53
 
54
  # Filter by level if level > 0 (for GAIA benchmark)
 
57
  df = df[df["Level"] == level]
58
 
59
  result=[]
 
 
 
 
 
60
 
61
  def is_ext(file_path, ext):
62
  return os.path.splitext(file_path)[1].lower() == ext.lower()