franlucc commited on
Commit
1a9dcdb
·
1 Parent(s): a72f911

add missing utils

Browse files
Files changed (1) hide show
  1. utils.py +48 -0
utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from hashlib import sha256
3
+ from typing import List, Tuple, Dict, Any
4
+ import math
5
+ import re
6
+
7
+ EXTRACTION_PROMPT = "All attempted answers, correct and incorrect"
8
+
9
+ def regex_compare(a: str, b: str) -> bool:
10
+ """
11
+ Compare all alphanum chars in a and b
12
+ """
13
+ a_chars = "".join(re.findall(r'\w', a))
14
+ b_chars = "".join(re.findall(r'\w', b))
15
+ return a_chars == b_chars or a_chars in b_chars
16
+
17
+ def print_info(db_connection):
18
+ tables = db_connection.execute("SHOW TABLES").fetchall()
19
+ # Iterate over each table and print its name and columns
20
+ for table in tables:
21
+ table_name = table[0]
22
+ print(f"Table: {table_name}")
23
+
24
+ # Get the columns for this table
25
+ columns = db_connection.execute(f"DESCRIBE {table_name}").fetchall()
26
+
27
+ # Print the column details
28
+ for column in columns:
29
+ print(f" - {column[0]} ({column[1]})") # column[0] is the column name, column[1] is the data type
30
+
31
+ print() # Add a blank line between tables for readability
32
+
33
+ def query_format_models(models: List[str]) -> str:
34
+ """
35
+ Format model names for the SQL query `WHERE <this_model> IN <models>
36
+ """
37
+ return "('" + "','".join(["completions-"+m for m in models]) + "')"
38
+
39
+ def get_completions(db_connector, query: str, **query_kwargs) -> pd.DataFrame:
40
+ """
41
+ If model has multiple completions, use only first.
42
+ """
43
+ df = db_connector.sql(query.format(**query_kwargs)).df()
44
+ df = df.groupby(["prompt_id", "model", "solution", "prompt"]).agg({"completion":"first"}).reset_index()
45
+ return df
46
+
47
+ def sha256_hash(text: str) -> str:
48
+ return sha256(bytes(text, "utf-8")).hexdigest()