File size: 393 Bytes
51be264 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import fire
import pandas as pd
def main(file_name: str):
    """Print mean ``acc_score`` per adapter, overall and per MMLU category.

    Reads the CSV at *file_name* and prints two tables to stdout:

    1. the mean ``acc_score`` for each ``adapter_name``;
    2. the mean ``acc_score`` for each (``adapter_name``, ``mmlu_categories``)
       pair.

    Args:
        file_name: Path to a CSV that must contain the columns
            ``adapter_name``, ``mmlu_categories`` and ``acc_score``.
    """
    df = pd.read_csv(file_name)

    print("\n### All Average ###")
    overall = df.groupby("adapter_name").aggregate({"acc_score": "mean"})
    # ``to_string()`` renders every row; a plain ``print(frame)`` collapses
    # frames longer than pandas' ``display.max_rows`` (60 by default) into a
    # head/tail preview with "...", hiding most of the averages.
    print(overall.to_string())

    print("\n### Category Average ###")
    per_category = df.groupby(["adapter_name", "mmlu_categories"]).aggregate(
        {"acc_score": "mean"}
    )
    print(per_category.to_string())
if __name__ == "__main__":
    # Run ``main`` as a command-line entry point via python-fire.
    fire.Fire(component=main)
|